kNN(K-Nearest Neighbor)最鄰近規則分類
K最近鄰分類演算法
方法的思路:如果一個樣本在特徵空間中的k個最相似(即特徵空間中最鄰近)的樣本中的大多數屬於這一類別,則該樣本也屬於這個類別。KNN演算法中,所選擇的鄰居都是已經正確分類的物件。該方法在定類決策上只依據最鄰近的一個或者幾個樣本的類別來決定待分類樣本所屬的類別。KNN方法雖然從原理上也依賴於極限定理,但在類別決策時,只與極少量的相鄰樣本有關。由於KNN方法主要靠周圍有限的鄰近的樣本,而不是靠判別類域的方法來確定所屬類別的,因此對於類域的交叉或重疊較多的待分樣本集來說,KNN方法較其他方法更為適合。
KNN演算法不僅可以用於分類,還可以用於迴歸。通過找出一個樣本的k個最近鄰居,將這些鄰居的屬性的平均值賦給該樣本,就可以得到該樣本的屬性。更有用的方法是將不同距離的鄰居對該樣本產生的影響給予不同的權值(weight),如權值與距離成正比(組合函式)。
該演算法在分類時有個主要的不足是,當樣本不平衡時,如一個類的樣本容量很大,而其他類樣本容量很小時,
有可能導致當輸入一個新樣本時,該樣本的K個鄰居中大容量類的樣本佔多數。 該演算法只計算“最近的”鄰居樣本,某
一類的樣本數量很大,那麼或者這類樣本並不接近目標樣本,或者這類樣本很靠近目標樣本。無論怎樣,數量並不能
影響執行結果。可以採用權值的方法(和該樣本距離小的鄰居權值大)來改進。
該方法的另一個不足之處是計算量較大,因為對每一個待分類的文字都要計算它到全體已知樣本的距離,才能求得它的K個最近鄰點。目前常用的
解決方法是事先對已知樣本點進行剪輯,事先去除對分類作用不大的樣本。該演算法比較適用於樣本容量比較大的類域
的自動分類,而那些樣本容量較小的類域採用這種演算法比較容易產生誤分
簡單來說,K-NN可以看成:有那麼一堆你已經知道分類的資料,然後當一個新資料進入的時候,就開始跟訓練資料
裡的每個點求距離,然後挑離這個訓練資料最近的K個點看看這幾個點屬於什麼型別,然後用少數服從多數的原則,
給新資料歸類。
演算法步驟:
step.1---初始化距離為最大值
step.2---計算未知樣本和每個訓練樣本的距離dist
step.3---得到目前K個最臨近樣本中的最大距離maxdist
step.4---如果dist小於maxdist,則將該訓練樣本作為K-最近鄰樣本
step.5---重複步驟2
step.6---統計K-最近鄰樣本中每個類標號出現的次數
step.7---選擇出現頻率最大的類標號作為未知樣本的類標號
function target=KNN(in,out,test,k)
% in: training samples data,n*d matrix
% out: training samples' class label,n*1
% test: testing data
% target: class label given by knn
% k: the number of neighbors
ClassLabel=unique(out);
c=length(ClassLabel);
n=size(in,1);
% target=zeros(size(test,1),1);
dist=zeros(size(in,1),1);
for j=1:size(test,1)
cnt=zeros(c,1);
for i=1:n
dist(i)=norm(in(i,:)-test(j,:));
end
[d,index]=sort(dist);
for i=1:k
ind=find(ClassLabel==out(index(i)));
cnt(ind)=cnt(ind)+1;
end
[m,ind]=max(cnt);
target(j)=ClassLabel(ind);
end
R語言的實現程式碼如下
library(class)
data(iris)
names(iris)
m1<-knn.cv(iris[,1:4],iris[,5],k=3,prob=TRUE)
attributes(.Last.value)
library(MASS)
m2<-lda(iris[,1:4],iris[,5]) 與判別分析進行比較
b<-data.frame(Sepal.Length=6,Sepal.Width=4,Petal.Length=5,Petal.Width=6)
p1<-predict(m2,b,type="class")
C++ 實現 :
// KNN.cpp K-最近鄰分類演算法
//
////////////////////////////////////////////////////////////////////////////////////////////////////////
#include <stdlib.h>
#include <stdio.h>
#include <memory.h>
#include <string.h>
#include <iostream>
#include <math.h>
#include <fstream>
using namespace std;
////////////////////////////////////////////////////////////////////////////////////////////////////////
//
// 巨集定義
//
////////////////////////////////////////////////////////////////////////////////////////////////////////
#define ATTR_NUM 4 //屬性數目
#define MAX_SIZE_OF_TRAINING_SET 1000 //訓練資料集的最大大小
#define MAX_SIZE_OF_TEST_SET 100 //測試資料集的最大大小
#define MAX_VALUE 10000.0 //屬性最大值
#define K 7
//結構體
struct dataVector {
int ID; //ID號
char classLabel[15]; //分類標號
double attributes[ATTR_NUM]; //屬性
};
struct distanceStruct {
int ID; //ID號
double distance; //距離
char classLabel[15]; //分類標號
};
////////////////////////////////////////////////////////////////////////////////////////////////////////
//
// 全域性變數
//
////////////////////////////////////////////////////////////////////////////////////////////////////////
struct dataVector gTrainingSet[MAX_SIZE_OF_TRAINING_SET]; //訓練資料集
struct dataVector gTestSet[MAX_SIZE_OF_TEST_SET]; //測試資料集
struct distanceStruct gNearestDistance[K]; //K個最近鄰距離
int curTrainingSetSize=0; //訓練資料集的大小
int curTestSetSize=0; //測試資料集的大小
////////////////////////////////////////////////////////////////////////////////////////////////////////
//
// 求 vector1=(x1,x2,...,xn)和vector2=(y1,y2,...,yn)的歐幾里德距離
//
////////////////////////////////////////////////////////////////////////////////////////////////////////
double Distance(struct dataVector vector1,struct dataVector vector2)
{
double dist,sum=0.0;
for(int i=0;i<ATTR_NUM;i++)
{
sum+=(vector1.attributes[i]-vector2.attributes[i])*(vector1.attributes[i]-vector2.attributes[i]);
}
dist=sqrt(sum);
return dist;
}
////////////////////////////////////////////////////////////////////////////////////////////////////////
//
// 得到gNearestDistance中的最大距離,返回下標
//
////////////////////////////////////////////////////////////////////////////////////////////////////////
int GetMaxDistance()
{
int maxNo=0;
for(int i=1;i<K;i++)
{
if(gNearestDistance[i].distance>gNearestDistance[maxNo].distance) maxNo = i;
}
return maxNo;
}
////////////////////////////////////////////////////////////////////////////////////////////////////////
//
// 對未知樣本Sample分類
//
////////////////////////////////////////////////////////////////////////////////////////////////////////
char* Classify(struct dataVector Sample)
{
double dist=0;
int maxid=0,freq[K],i,tmpfreq=1;;
char *curClassLable=gNearestDistance[0].classLabel;
memset(freq,1,sizeof(freq));
//step.1---初始化距離為最大值
for(i=0;i<K;i++)
{
gNearestDistance[i].distance=MAX_VALUE;
}
//step.2---計算K-最近鄰距離
for(i=0;i<curTrainingSetSize;i++)
{
//step.2.1---計算未知樣本和每個訓練樣本的距離
dist=Distance(gTrainingSet[i],Sample);
//step.2.2---得到gNearestDistance中的最大距離
maxid=GetMaxDistance();
//step.2.3---如果距離小於gNearestDistance中的最大距離,則將該樣本作為K-最近鄰樣本
if(dist<gNearestDistance[maxid].distance)
{
gNearestDistance[maxid].ID=gTrainingSet[i].ID;
gNearestDistance[maxid].distance=dist;
strcpy(gNearestDistance[maxid].classLabel,gTrainingSet[i].classLabel);
}
}
//step.3---統計每個類出現的次數
for(i=0;i<K;i++)
{
for(int j=0;j<K;j++)
{
if((i!=j)&&(strcmp(gNearestDistance[i].classLabel,gNearestDistance[j].classLabel)==0))
{
freq[i]+=1;
}
}
}
//step.4---選擇出現頻率最大的類標號
for(i=0;i<K;i++)
{
if(freq[i]>tmpfreq)
{
tmpfreq=freq[i];
curClassLable=gNearestDistance[i].classLabel;
}
}
return curClassLable;
}
////////////////////////////////////////////////////////////////////////////////////////////////////////
//
// 主函式
//
////////////////////////////////////////////////////////////////////////////////////////////////////////
void main()
{
char c;
char *classLabel="";
int i,j, rowNo=0,TruePositive=0,FalsePositive=0;
ifstream filein("iris.data");
FILE *fp;
if(filein.fail()){cout<<"Can't open data.txt"<<endl; return;}
//step.1---讀檔案
while(!filein.eof())
{
rowNo++;//第一組資料rowNo=1
if(curTrainingSetSize>=MAX_SIZE_OF_TRAINING_SET)
{
cout<<"The training set has "<<MAX_SIZE_OF_TRAINING_SET<<" examples!"<<endl<<endl;
break ;
}
//rowNo%3!=0的100組資料作為訓練資料集
if(rowNo%3!=0)
{
gTrainingSet[curTrainingSetSize].ID=rowNo;
for(int i = 0;i < ATTR_NUM;i++)
{
filein>>gTrainingSet[curTrainingSetSize].attributes[i];
filein>>c;
}
filein>>gTrainingSet[curTrainingSetSize].classLabel;
curTrainingSetSize++;
}
//剩下rowNo%3==0的50組做測試資料集
else if(rowNo%3==0)
{
gTestSet[curTestSetSize].ID=rowNo;
for(int i = 0;i < ATTR_NUM;i++)
{
filein>>gTestSet[curTestSetSize].attributes[i];
filein>>c;
}
filein>>gTestSet[curTestSetSize].classLabel;
curTestSetSize++;
}
}
filein.close();
//step.2---KNN演算法進行分類,並將結果寫到檔案iris_OutPut.txt
fp=fopen("iris_OutPut.txt","w+t");
//用KNN演算法進行分類
fprintf(fp,"************************************程式說明***************************************\n");
fprintf(fp,"** 採用KNN演算法對iris.data分類。為了操作方便,對各組資料新增rowNo屬性,第一組rowNo=1!\n");
fprintf(fp,"** 共有150組資料,選擇rowNo模3不等於0的100組作為訓練資料集,剩下的50組做測試資料集\n");
fprintf(fp,"***********************************************************************************\n\n");
fprintf(fp,"************************************實驗結果***************************************\n\n");
for(i=0;i<curTestSetSize;i++)
{
fprintf(fp,"************************************第%d組資料**************************************\n",i+1);
classLabel =Classify(gTestSet[i]);
if(strcmp(classLabel,gTestSet[i].classLabel)==0)//相等時,分類正確
{
TruePositive++;
}
cout<<"rowNo: ";
cout<<gTestSet[i].ID<<" \t";
cout<<"KNN分類結果: ";
cout<<classLabel<<"(正確類標號: ";
cout<<gTestSet[i].classLabel<<")\n";
fprintf(fp,"rowNo: %3d \t KNN分類結果: %s ( 正確類標號: %s )\n",gTestSet[i].ID,classLabel,gTestSet[i].classLabel);
if(strcmp(classLabel,gTestSet[i].classLabel)!=0)//不等時,分類錯誤
{
// cout<<" ***分類錯誤***\n";
fprintf(fp," ***分類錯誤***\n");
}
fprintf(fp,"%d-最臨近資料:\n",K);
for(j=0;j<K;j++)
{
// cout<<gNearestDistance[j].ID<<"\t"<<gNearestDistance[j].distance<<"\t"<<gNearestDistance[j].classLabel[15]<<endl;
fprintf(fp,"rowNo: %3d \t Distance: %f \tClassLable: %s\n",gNearestDistance[j].ID,gNearestDistance[j].distance,gNearestDistance[j].classLabel);
}
fprintf(fp,"\n");
}
FalsePositive=curTestSetSize-TruePositive;
fprintf(fp,"***********************************結果分析**************************************\n",i);
fprintf(fp,"TP(True positive): %d\nFP(False positive): %d\naccuracy: %f\n",TruePositive,FalsePositive,double(TruePositive)/(curTestSetSize-1));
fclose(fp);
return;
}
以上內容為參考網上有關資料;加以總結;