1. 程式人生 > 人工智慧(AI)之KNN的基本實現

人工智慧(AI)之KNN的基本實現

資料集下載地址:點我下載(原文的下載連結未保留)
本文主要介紹KNN的實現思想:

  1. KNN的主要思想是:計算每個測試樣本與所有訓練樣本之間的距離(歐氏距離、餘弦距離、曼哈頓距離等),取出最相似(距離最近)的前K個訓練樣本,再根據它們的標籤對該測試樣本進行預測
  2. 經過測試發現,就本次的資料集而言,將餘弦距離與歐氏距離加權組合來確定預測值,效果較好;但這一結論僅適用於本次的訓練資料
  3. KNN當中還有很多細節可以優化,比如對資料集進行歸一化;而歸一化的方法也有很多種,具體採用哪一種要看當前的資料集,找到適合的才是最好的
  4. 總之對於預測,找好模型才是最重要的,框架確定之後,再來討論具體的優化會更有效果
#include <string.h>

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <functional>
#include <iostream>
#include <iterator>
#include <map>
#include <queue>
#include <set>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

using namespace std;

#define ANGER 0
#define DISGUST 1
#define FEAR 2
#define JOY 3 #define SAD 4 #define SURPRISE 5 char c[300]; priority_queue<double,vector<double>,greater<double> >q; map<double,int>map1; //從小到大 map<double,int, greater<double> >map2; //從大到小double> >兩者空格不可少 const string Str1 = "train", Str2 = "test"; set<string
>
sets; bool vector_old[2000][4000]; double vector2[2000][4000]; double proba[9][2000]; double newproba[9][2000]; double dis_save[2000]; double K; int num1=0; void readanger() { ifstream in("G:/桌面文件/學/大三上學期/第二學期/人工智慧/實驗/Lab 2 實驗材料/gold_train/anger_train.txt"); int i = 0; while (in && i < 246){ memset(c, 0, sizeof(c)); in.getline(c, 300); string s; s.append(c, 300); stringstream ss(s); ss >> s; // 第一個單詞不用 double d; ss >> d; proba[ANGER][i++] = d; } in.close(); } void readdisgust() { ifstream in("G:/桌面文件/學/大三上學期/第二學期/人工智慧/實驗/Lab 2 實驗材料/gold_train/disgust_train.txt"); int i = 0; while (in && i < 246){ memset(c, 0, sizeof(c)); in.getline(c, 300); string s; s.append(c, 300); stringstream ss(s); ss >> s; // 第一個單詞不用 double d; ss >> d; proba[DISGUST][i++] = d; } in.close(); } void readfear() { ifstream in("G:/桌面文件/學/大三上學期/第二學期/人工智慧/實驗/Lab 2 實驗材料/gold_train/fear_train.txt"); int i = 0; while (in && i < 246){ memset(c, 0, sizeof(c)); in.getline(c, 300); string s; s.append(c, 300); stringstream ss(s); ss >> s; // 第一個單詞不用 double d; ss >> d; proba[FEAR][i++] = d; } in.close(); } void readjoy() { ifstream in("G:/桌面文件/學/大三上學期/第二學期/人工智慧/實驗/Lab 2 實驗材料/gold_train/joy_train.txt"); int i = 0; while (in && i < 246){ memset(c, 0, sizeof(c)); in.getline(c, 300); string s; s.append(c, 300); stringstream ss(s); ss >> s; // 第一個單詞不用 double d; ss >> d; proba[JOY][i++] = d; } in.close(); } void readsad() { ifstream in("G:/桌面文件/學/大三上學期/第二學期/人工智慧/實驗/Lab 2 實驗材料/gold_train/sad_train.txt"); int i = 0; while (in && i < 246){ memset(c, 0, sizeof(c)); in.getline(c, 300); string s; s.append(c, 300); stringstream ss(s); ss >> s; // 第一個單詞不用 double d; ss >> d; proba[SAD][i++] = d; } in.close(); } void readsurprise() { ifstream in("G:/桌面文件/學/大三上學期/第二學期/人工智慧/實驗/Lab 2 實驗材料/gold_train/surprise_train.txt"); int i = 0; while (in && i < 246){ memset(c, 0, sizeof(c)); in.getline(c, 300); string s; s.append(c, 300); stringstream ss(s); ss >> s; // 第一個單詞不用 double d; ss >> d; proba[SURPRISE][i++] = d; } in.close(); } 
void get_proba() { readanger(); readdisgust(); readfear(); readsad(); readjoy(); readsurprise(); } void get_word() { ifstream in("G:/桌面文件/學/大三上學期/第二學期/人工智慧/實驗/Lab 2 實驗材料/Dataset_words.txt"); ofstream out("G:/桌面文件/學/大三上學期/第二學期/人工智慧/實驗/Lab 2 實驗材料/anger.txt"); string str; int i = 0; if(in&&out) { while(getline(in,str)) { if(i==0) { i++; continue; } else { int j = 0; stringstream ss; ss << str; while(!ss.eof()) { { if(j==0) { j++; ss >> str; str = " "; sets.insert(str); } //cout << str <<endl; else { ss >> str; sets.insert(str); } } } } } }else{ cerr<<"open in or out file error"<<endl; } for(set<string>::iterator it = sets.begin();it != sets.end();it++) { if(*it != " ") { out << *it << endl; //cout << *it << endl; } } in.close(); out.close(); } void clear_stopwords() { fstream in; in.open("G:/桌面文件/學/大三上學期/第二學期/人工智慧/實驗/Lab 2 實驗材料/Foxstoplist (1).txt"); ofstream out("G:/桌面文件/學/大三上學期/第二學期/人工智慧/實驗/Lab 2 實驗材料/Foxstoplistout.txt"); string str; if(in) { while(getline(in,str)) { stringstream ss; ss << str; while(!ss.eof()) { ss >> str; out << str <<endl; for(set<string>::iterator it = sets.begin();it != sets.end();) { if(*it == str) { sets.erase(it); break; } else { it++; } } } } } in.close(); out.close(); } void vector_out() { ifstream in("G:/桌面文件/學/大三上學期/第二學期/人工智慧/實驗/Lab 2 實驗材料/Dataset_words.txt"); ofstream out("G:/桌面文件/學/大三上學期/第二學期/人工智慧/實驗/Lab 2 實驗材料/vector.txt"); string str; int i = 0; int row_num = 0; while(in&&out) { while(getline(in,str)) { if(i==0) { i++; continue; } else { int j = 0; stringstream ss; ss << str; while(!ss.eof()) { int lin_num = 0; if(j==0) { j++; ss >> str; } else { ss >> str; for(set<string>::iterator it=sets.begin(); it != sets.end() ; it++) { if(*it == str) { vector_old[row_num][lin_num] = true; } lin_num++; } } } } row_num++; } } string wenben = "文字編號 "; out << wenben; for(set<string>::iterator it= sets.begin(); it != sets.end(); it++) { out << *it << " "; } in.close(); out.close(); } void compute_dis(double K) { for (int i = 0; i < 1246; i++){ 
double sum = 0; for (int j = 0; j < sets.size(); j++){ if (vector_old[i][j]) { sum++; } } for (int j = 0; j < sets.size(); j++){ vector2[i][j] = vector_old[i][j]*1.0/sum; //out << vector2[i][j] << " "; } //out <<endl; } for(int mood_n = 0 ; mood_n < 6 ; mood_n++) { for(int i = 0 ; i < 1000 ; i++) { int dis_num=0; double pro_sum = 0; double dis; int pos; double max_dis = 0; double min_dis = 10000; map<double,int>map1; map<double,int, greater<double> >map2; for(int j = 0 ; j < 246 ; j++) { dis = 0; double angle = 0; double xy_sum=0; double xx=0; double yy=0; double save_angle[2000]={0}; for(int k = 0 ; k < sets.size() ; k++) { xy_sum+=vector_old[i+246][k]*vector_old[j][k]; xx+=vector_old[i+246][k]*vector_old[i+246][k]; yy+=vector_old[j][k]*vector_old[j][k]; //dis += (vector2[i+246][k]-vector2[j][k])*(vector2[i+246][k]-vector2[j][k]); } dis_save[j] = xx + yy - 2*xy_sum; angle = xy_sum/(sqrt(xx)*sqrt(yy)); //angle = angle*(1/sqrt(dis_save[j])); angle = 0.8*angle + 0.2*dis_save[j]; map2.insert(make_pair(angle,j)); /* for(int k = 0 ; k < sets.size() ; k++) { dis += (vector2[i+246][k]-vector2[j][k])*(vector2[i+246][k]-vector2[j][k]); } dis = sqrt(dis); dis_sum+=dis; map1.insert(make_pair(dis,j)); */ } cout << "i:" << i <<endl; /* for(map<double,int>::iterator it1 = map1.begin();it1!=map1.end();it1++) { double temp = it1->first; temp = temp/dis_sum; map1.insert(make_pair(temp,it1->second)); } */ int K_i = 1; double dis_sum = 0; for(map<double,int>::iterator it = map2.begin(); it != map2.end(); it++) { if(K_i>K) { break; } else { K_i++; pro_sum += proba[mood_n][it->second]; //dis_sum+=(1/(dis_save[it->second]*dis_save[it->second])); /* for(map<double,int>::iterator it1 = map1.begin();it1!=map1.end();it1++) { if(it->second == it1->second) { pro_sum = pro_sum + 0.6* break; } } */ } } newproba[mood_n][i] = pro_sum*1.0/K; } } cout << "happy" <<endl; } void print() { for(int i = 0 ; i < 6 ; i++) { ofstream f; switch(i) { case ANGER: f.open("C:/Users/windowos 
7/Desktop/AILab/predict_test/anger_predict.txt"); break; case DISGUST: f.open("C:/Users/windowos 7/Desktop/AILab/predict_test/disgust_predict.txt"); break; case FEAR: f.open("C:/Users/windowos 7/Desktop/AILab/predict_test/fear_predict.txt"); break; case JOY: f.open("C:/Users/windowos 7/Desktop/AILab/predict_test/joy_predict.txt"); break; case SAD: f.open("C:/Users/windowos 7/Desktop/AILab/predict_test/sad_predict.txt"); break; case SURPRISE: f.open("C:/Users/windowos 7/Desktop/AILab/predict_test/surprise_predict.txt"); break; } for(int j = 0 ; j < 1000 ; j++) { f << newproba[i][j] <<endl; //cout << newproba[i][j] <<endl; } f.close(); } } int main() { cout << "請輸入k:" <<endl; cin >> K; get_word(); cout << 0 <<endl; clear_stopwords(); cout << 1 <<endl; get_proba(); cout << 2 <<endl; vector_out(); cout << 3 <<endl; compute_dis(K); cout << 4 <<endl; print(); cout << 5 <<endl; cout << sets.size() <<endl; return 0; }