人工智慧(AI)之KNN的基本實現
阿新 • • 發佈:2019-02-02
資料集下載地址點我下載
本文主要介紹KNN的實現思想:
- KNN的主要思想就是:計算每個測試樣本與所有訓練樣本之間的距離(歐氏距離、餘弦距離、曼哈頓距離等),然後取出最相似的前K個訓練樣本,用它們的標籤對該測試樣本進行預測
- 通過測試之後發現,就本次的資料集而言,把餘弦距離以及歐氏距離進行加權來確定預測值結果較好,但僅僅是對於本次的訓練資料而言
- KNN當中也還有很多細節可以去優化的,比如說對資料集進行一定的歸一化,而歸一化的方法也是很多的,具體怎麼取,也是要看當前的資料集,找到適合的才是最好的
- 總之對於預測,找好模型才是最重要的,框架確定之後,再來討論具體的優化會更有效果
#include <iostream>
#include <fstream>
#include <cstring>
#include <cstdlib>
#include <sstream>
#include <string.h>
#include <set>
#include <cmath>
#include <iterator>
#include <queue>
#include <map>
using namespace std;
// Emotion indices: used as the first subscript of proba[][] and newproba[][].
#define ANGER 0
#define DISGUST 1
#define FEAR 2
#define JOY 3
#define SAD 4
#define SURPRISE 5
// 300-byte character scratch buffer.
char c[300];
// Min-heap of doubles (greater<> puts the smallest element on top).
priority_queue<double,vector<double>,greater<double> >q;
map<double,int>map1; // keys iterate in ascending order
map<double,int, greater<double> >map2; // keys iterate in descending order; the space in "> >" is required (pre-C++11 parsers read ">>" as a shift)
// NOTE(review): Str1/Str2 are not referenced anywhere in the visible code.
const string Str1 = "train", Str2 = "test";
// Vocabulary: words inserted by get_word() and pruned by clear_stopwords().
// Its iteration (sorted) order defines the column order of vector_old.
set<string > sets;
// vector_old[text][word]: set to true by vector_out() when the text contains
// the word in that vocabulary column.
bool vector_old[2000][4000];
// Row-scaled copy of vector_old, written by compute_dis().
double vector2[2000][4000];
// proba[emotion][row]: values loaded from the gold training files by the
// read*() functions.
double proba[9][2000];
// newproba[emotion][row]: predictions written by compute_dis() and dumped to
// disk by print().
double newproba[9][2000];
// Per-training-row values written by compute_dis().
double dis_save[2000];
// Neighbour count: read from stdin in main() and passed to compute_dis().
double K;
// NOTE(review): num1 is not referenced anywhere in the visible code.
int num1=0;
// Loads one gold-label file into proba[mood][0], proba[mood][1], ...
//
// For every line, the first whitespace-delimited token is skipped and the
// token after it is parsed as a double and stored. At most 246 lines are
// consumed (the training-row count hard-coded throughout this program).
// A line without a parsable number stores 0.0. If the file cannot be opened,
// nothing is stored.
//
// path: file to read.
// mood: first subscript of proba (ANGER .. SURPRISE).
static void read_mood_file(const char* path, int mood)
{
    const int kMaxRows = 246;

    ifstream in(path);
    string line;
    // The loop condition tests the read itself, not the stream state before
    // it, so the failed read at end-of-file never stores an extra row. Using
    // std::getline on a std::string also removes the 300-byte line limit.
    for (int row = 0; row < kMaxRows && getline(in, line); ++row) {
        stringstream fields(line);
        string skipped;      // first word of the line is not used
        double value = 0.0;  // stays 0.0 when the line has no number
        fields >> skipped >> value;
        proba[mood][row] = value;
    }
}

// Loads the anger training labels into proba[ANGER].
void readanger()
{
    read_mood_file("G:/桌面文件/學/大三上學期/第二學期/人工智慧/實驗/Lab 2 實驗材料/gold_train/anger_train.txt", ANGER);
}

// Loads the disgust training labels into proba[DISGUST].
void readdisgust()
{
    read_mood_file("G:/桌面文件/學/大三上學期/第二學期/人工智慧/實驗/Lab 2 實驗材料/gold_train/disgust_train.txt", DISGUST);
}

// Loads the fear training labels into proba[FEAR].
void readfear()
{
    read_mood_file("G:/桌面文件/學/大三上學期/第二學期/人工智慧/實驗/Lab 2 實驗材料/gold_train/fear_train.txt", FEAR);
}

// Loads the joy training labels into proba[JOY].
void readjoy()
{
    read_mood_file("G:/桌面文件/學/大三上學期/第二學期/人工智慧/實驗/Lab 2 實驗材料/gold_train/joy_train.txt", JOY);
}

// Loads the sadness training labels into proba[SAD].
void readsad()
{
    read_mood_file("G:/桌面文件/學/大三上學期/第二學期/人工智慧/實驗/Lab 2 實驗材料/gold_train/sad_train.txt", SAD);
}

// Loads the surprise training labels into proba[SURPRISE].
void readsurprise()
{
    read_mood_file("G:/桌面文件/學/大三上學期/第二學期/人工智慧/實驗/Lab 2 實驗材料/gold_train/surprise_train.txt", SURPRISE);
}
// Fills proba[][] for all six emotions by running each per-emotion loader
// exactly once, in the order listed below.
void get_proba()
{
    typedef void (*Loader)();
    static const Loader loaders[] = {
        readanger, readdisgust, readfear, readsad, readjoy, readsurprise
    };
    const int loader_count = sizeof(loaders) / sizeof(loaders[0]);

    for (int idx = 0; idx < loader_count; ++idx) {
        loaders[idx]();
    }
}
// Builds the vocabulary in the global `sets` from Dataset_words.txt and writes
// it to anger.txt, one word per line, in the set's sorted order.
//
// The first line of the input is skipped. On every following line the first
// whitespace-delimited token is skipped and each remaining token is inserted
// into `sets`. Only real tokens are inserted: no placeholder entry is added
// for the skipped token, so sets.size() is exactly the number of distinct
// words.
//
// If either file cannot be opened, an error is printed to stderr and `sets`
// is left untouched.
void get_word()
{
    ifstream in("G:/桌面文件/學/大三上學期/第二學期/人工智慧/實驗/Lab 2 實驗材料/Dataset_words.txt");
    ofstream out("G:/桌面文件/學/大三上學期/第二學期/人工智慧/實驗/Lab 2 實驗材料/anger.txt");
    if (!in || !out) {
        cerr<<"open in or out file error"<<endl;
        return;
    }

    string line;
    getline(in, line);  // first line of the file is not data
    while (getline(in, line)) {
        stringstream tokens(line);
        string word;
        tokens >> word;  // first token of each line is not a vocabulary word
        // Extraction is the loop condition, so a failed read (blank line,
        // trailing whitespace) never re-inserts a stale value.
        while (tokens >> word) {
            sets.insert(word);
        }
    }

    for (set<string>::const_iterator it = sets.begin(); it != sets.end(); ++it) {
        out << *it << '\n';
    }
}
// Removes every stop word listed in "Foxstoplist (1).txt" from the global
// vocabulary `sets`, and echoes each stop word that was read, one per line,
// to Foxstoplistout.txt.
//
// The stop list is split on whitespace, so several words on one line are
// handled individually. If the stop list cannot be opened, `sets` is left
// unchanged.
void clear_stopwords()
{
    // Read-only stream: the list is never written, so it must also open when
    // the file is not writable.
    ifstream in("G:/桌面文件/學/大三上學期/第二學期/人工智慧/實驗/Lab 2 實驗材料/Foxstoplist (1).txt");
    ofstream out("G:/桌面文件/學/大三上學期/第二學期/人工智慧/實驗/Lab 2 實驗材料/Foxstoplistout.txt");

    string line;
    while (getline(in, line)) {
        stringstream tokens(line);
        string word;
        // Extraction is the loop condition, so blank lines and trailing
        // whitespace do not echo a stale or empty word.
        while (tokens >> word) {
            out << word << '\n';
            sets.erase(word);  // keyed erase; no-op when the word is absent
        }
    }
}
// Fills the global presence matrix `vector_old` from Dataset_words.txt and
// writes the vocabulary header line to vector.txt.
//
// The first line of the input is skipped. Each following line becomes one row
// (row 0 is the first data line); its first token is skipped, and for every
// other token that is in the vocabulary `sets` the matching column is set to
// true. Columns follow the iteration order of `sets`. Rows and columns that
// would fall outside vector_old are ignored instead of being written out of
// bounds.
//
// The matrix is only filled when both files opened; the header is written to
// vector.txt regardless (a no-op if that stream failed).
void vector_out()
{
    ifstream in("G:/桌面文件/學/大三上學期/第二學期/人工智慧/實驗/Lab 2 實驗材料/Dataset_words.txt");
    ofstream out("G:/桌面文件/學/大三上學期/第二學期/人工智慧/實驗/Lab 2 實驗材料/vector.txt");

    // Capacity is derived from the array itself so it cannot drift from the
    // declaration.
    const int max_rows = sizeof(vector_old) / sizeof(vector_old[0]);
    const int max_cols = sizeof(vector_old[0]) / sizeof(vector_old[0][0]);

    // Word -> column index, built once so each token costs one map lookup
    // instead of a walk over the whole vocabulary.
    map<string, int> column_of;
    int next_col = 0;
    for (set<string>::const_iterator it = sets.begin(); it != sets.end(); ++it) {
        column_of[*it] = next_col++;
    }

    if (in && out) {
        string line;
        getline(in, line);  // first line of the file is not data
        int row = 0;
        while (row < max_rows && getline(in, line)) {
            stringstream tokens(line);
            string word;
            tokens >> word;  // first token of each line is not a vocabulary word
            while (tokens >> word) {
                map<string, int>::const_iterator hit = column_of.find(word);
                if (hit != column_of.end() && hit->second < max_cols) {
                    vector_old[row][hit->second] = true;
                }
            }
            ++row;  // blank lines still occupy a row, keeping rows aligned with input lines
        }
    }

    out << "文字編號 ";
    for (set<string>::const_iterator it = sets.begin(); it != sets.end(); ++it) {
        out << *it << " ";
    }
}
// K-nearest-neighbour prediction of the six emotion values for every test
// text; results go to newproba[mood][test_index].
//
// Row layout used by this function: rows 0..245 of vector_old are treated as
// training texts (their labels are read from proba[mood][row]) and rows
// 246..1245 as the 1000 texts to predict.
//
// K: number of neighbours. The floor(K) best-ranked training rows are
//    averaged. If K < 1 no neighbour is selected and the prediction is 0.
//
// Side effects: fills vector2 (each row of vector_old divided by its number
// of set entries; all-zero rows stay 0), overwrites dis_save with the squared
// distances of the last test text, and prints one "i:<n>" progress line per
// test text followed by "happy" to stdout.
void compute_dis(double K)
{
    const int kTrainRows = 246;
    const int kTestRows = 1000;
    const int kTotalRows = kTrainRows + kTestRows;
    const int kMoods = 6;
    const int kMaxCols = sizeof(vector_old[0]) / sizeof(vector_old[0][0]);

    // Never index past a row of vector_old, even with an oversized vocabulary.
    int dims = static_cast<int>(sets.size());
    if (dims > kMaxCols) {
        dims = kMaxCols;
    }

    // Number of set entries per row. For 0/1 vectors this is also the squared
    // Euclidean norm, so it is reused below instead of being re-summed for
    // every (test, train) pair.
    double ones[kTotalRows];
    for (int row = 0; row < kTotalRows; ++row) {
        double count = 0;
        for (int col = 0; col < dims; ++col) {
            if (vector_old[row][col]) {
                ++count;
            }
        }
        ones[row] = count;
        for (int col = 0; col < dims; ++col) {
            // Guard the division: a text with no vocabulary word would
            // otherwise produce 0/0 = NaN.
            vector2[row][col] = (count > 0) ? vector_old[row][col] * 1.0 / count : 0.0;
        }
    }

    // Score -> training row, best score first. A multimap keeps training rows
    // with equal scores; a plain map would silently drop all but one of them
    // and the average would then cover fewer than K neighbours.
    typedef multimap<double, int, greater<double> > Ranking;

    // The ranking does not depend on the emotion, so it is built once per
    // test text and reused for all six emotions.
    for (int i = 0; i < kTestRows; ++i) {
        const int test_row = i + kTrainRows;
        const double xx = ones[test_row];

        Ranking ranked;
        for (int j = 0; j < kTrainRows; ++j) {
            double xy_sum = 0;  // number of words the two texts share
            for (int col = 0; col < dims; ++col) {
                if (vector_old[test_row][col] && vector_old[j][col]) {
                    ++xy_sum;
                }
            }
            const double yy = ones[j];

            // Squared Euclidean distance: |x|^2 + |y|^2 - 2 x.y
            dis_save[j] = xx + yy - 2 * xy_sum;

            // Cosine similarity; defined as 0 when either text is empty so
            // that no NaN is ever used as an ordering key.
            const double cosine =
                (xx > 0 && yy > 0) ? xy_sum / (sqrt(xx) * sqrt(yy)) : 0.0;

            // NOTE(review): this adds a similarity (larger = closer) to a
            // distance (larger = farther) and ranks by the largest sum. The
            // weights are kept exactly as they were to preserve the existing
            // ranking; confirm that this mix is intended.
            const double score = 0.8 * cosine + 0.2 * dis_save[j];
            ranked.insert(make_pair(score, j));
        }
        cout << "i:" << i <<endl;

        double label_sum[kMoods] = {0};
        int used = 0;
        // "used + 1 <= K" selects floor(K) neighbours.
        for (Ranking::const_iterator it = ranked.begin();
             it != ranked.end() && used + 1 <= K; ++it) {
            for (int mood = 0; mood < kMoods; ++mood) {
                label_sum[mood] += proba[mood][it->second];
            }
            ++used;
        }

        // Average over the neighbours actually used, so K larger than the
        // training set (or a fractional K) still yields a true mean.
        for (int mood = 0; mood < kMoods; ++mood) {
            newproba[mood][i] = (used > 0) ? label_sum[mood] / used : 0.0;
        }
    }
    cout << "happy" <<endl;
}
// Writes the predictions in newproba to one text file per emotion: 1000
// values per file, one value per line, in row order.
//
// A file that cannot be opened is reported on stderr and skipped; the
// remaining emotions are still written.
void print()
{
    const int kMoods = 6;
    const int kTestRows = 1000;

    // Output path for each emotion, indexed by the emotion constant so the
    // table stays correct if the constants are ever renumbered within 0..5.
    const char* paths[kMoods];
    paths[ANGER] = "C:/Users/windowos 7/Desktop/AILab/predict_test/anger_predict.txt";
    paths[DISGUST] = "C:/Users/windowos 7/Desktop/AILab/predict_test/disgust_predict.txt";
    paths[FEAR] = "C:/Users/windowos 7/Desktop/AILab/predict_test/fear_predict.txt";
    paths[JOY] = "C:/Users/windowos 7/Desktop/AILab/predict_test/joy_predict.txt";
    paths[SAD] = "C:/Users/windowos 7/Desktop/AILab/predict_test/sad_predict.txt";
    paths[SURPRISE] = "C:/Users/windowos 7/Desktop/AILab/predict_test/surprise_predict.txt";

    for (int mood = 0; mood < kMoods; ++mood) {
        ofstream f(paths[mood]);
        if (!f) {
            // Without this check a wrong output directory loses every
            // prediction without any diagnostic.
            cerr << "cannot open output file: " << paths[mood] << endl;
            continue;
        }
        for (int row = 0; row < kTestRows; ++row) {
            f << newproba[mood][row] << '\n';  // '\n' avoids a flush per line
        }
    }
}
// Program entry point: reads the neighbour count K from stdin, then runs the
// pipeline (vocabulary -> stop-word removal -> training labels -> presence
// matrix -> KNN prediction -> result files). A stage number is printed to
// stdout after each stage, and the vocabulary size is printed at the end.
//
// Returns 0 on success, 1 when K is missing, not a number, or below 1.
int main()
{
    cout << "請輸入k:" <<endl;
    // compute_dis needs at least one neighbour to average over. "!(K >= 1)"
    // also rejects NaN.
    if (!(cin >> K) || !(K >= 1)) {
        cerr << "invalid k: expected a number >= 1" << endl;
        return 1;
    }

    get_word();
    cout << 0 <<endl;

    clear_stopwords();
    cout << 1 <<endl;

    get_proba();
    cout << 2 <<endl;

    vector_out();
    cout << 3 <<endl;

    compute_dis(K);
    cout << 4 <<endl;

    print();
    cout << 5 <<endl;

    cout << sets.size() <<endl;
    return 0;
}