1. 程式人生 > >最近鄰分類器及MATLAB實現

最近鄰分類器及MATLAB實現

應用背景:在前面一些影象處理相關的文章中,已經說到影象的特徵提取,在選擇好一些主要特徵之後,那麼我們用這些特徵做什麼用呢,我們的主要目的是利用這些特徵對影象進行分類。接下來的問題是怎麼分類,這裡介紹最近鄰分類,它是一種最簡單的分類方法。

基本思想:最近鄰分類,顧名思義,距鄰居最近,則與鄰居同類。也就是說,一個待分類的單個樣本A,放入已分好類的多個樣本群Q中,從Q中選擇kA的鄰居,通過計算A與鄰居之間的某種關係後得出A與這k個鄰居最相似,那麼就把A分為這k個鄰居中出現次數最多的類,因此最近鄰分類也稱k最近鄰分類(k  nearest  neighbor,  KNN)。這種分類方法基本分類3步:1、找待分類樣本與已分類樣本之間的關係,這裡指計算它們之間的距離;2、找距離最近的k

個已分類的樣本;3、分類,從這k個樣本中找出出現次數最多的類,那麼待分類樣本屬於該類。

數學原理:在特徵空間中,把每個類的所有樣本的平均值表示為該類,則第i類樣本的均值為:

                                                                                                          (1)

其中,Ni為第i類樣本的樣本數目,Wi為第i類樣本集合,W為總類別數目。

樣本之間的距離取歐氏距離,當對一個未知模式 x 進行分類時,需要分別計算 x 與各個類的歐氏距離,如下式所示

                                                                                 

                    (2)

其中,|| x-mi ||=((x-mi)^T(x-mi))^1/2,表示歐幾里得範數,即向量的模。

MATLAB實現

演算法虛擬碼如下

        initialize  Di(x)=max

        for   i=1:w

                  Di(x)=|| x-mi ||

        end

        find k neighbors of  x

        find  max_Di(x) in k neighbors of  x      

        for   i=1:w

                  if  Di(x)<max_Di(x)

                  i is the neighbors of  x

        end

        count the number of classes in the k nearest neighbors

MATLAB程式碼如下

二維平面兩類分類問題:

clear;
k=9;       % 最近鄰居的數目
num_po=100;            
x11=rand(num_po,1);
x12=rand(num_po,1);
x1=[x11 x12];  
y1=ones(num_po,1);     
num_ne=100;            
x21=rand(num_ne,1)+1;
x22=rand(num_ne,1);    
x2=[x21 x22];
y2=-1*ones(num_ne,1);   
x=[x1;x2];
y=[y1;y2];
ClassLabel=unique(y);
num_t=20;                    % 測試樣本
test1=rand(num_t,1)+0.5;
test2=rand(num_t,1);
test=[test1 test2];        
for num=1:num_t
    for i=1:(num_po+num_ne)
        dis(i)=norm(test(num,:)-x(i,:));
    end
    [dis_s,index]=sort(dis);
    cnt=zeros(1,2);
    for j=1:k
        ind=find(ClassLabel==y(index(j)));
        cnt(ind)=cnt(ind)+1;
    end
    [m,ind]=max(cnt);
    y_test(num)=ClassLabel(ind);% 測試樣本的標記
end
for i=1:(num_po+num_ne)         % 訓練樣本圖示
    if y(i)>0
        plot(x(i,1),x(i,2),'r+');
        hold on
    else
        plot(x(i,1),x(i,2),'b.');
        hold on
    end
end
for i=1:num_t                   % 測試樣樣本圖示
    if y_test(i)>0
        plot(test(i,1),test(i,2),'g+');
        title('K-最近鄰分類器');
        hold on
    else
        plot(test(i,1),test(i,2),'y.');
        hold on
    end
end

執行結果如下


對於二維平面兩類分類問題,每個訓練樣本都有一個距離必須度量,資料計算量過大,耗費大量時間。

下面的例子使用樣本資料的均值來代表類

%鳶尾花屬植物資料集
%花萼長度 花萼寬度 花瓣長度 花瓣寬度 屬種
yuanwei_data=[5.1,3.5,1.4,0.2,1
4.9,3.0,1.4,0.2,1
4.7,3.2,1.3,0.2,1
4.6,3.1,1.5,0.2,1
5.0,3.6,1.4,0.2,1
5.4,3.9,1.7,0.4,1
4.6,3.4,1.4,0.3,1
5.0,3.4,1.5,0.2,1
4.4,2.9,1.4,0.2,1
4.9,3.1,1.5,0.1,1
5.4,3.7,1.5,0.2,1
4.8,3.4,1.6,0.2,1
4.8,3.0,1.4,0.1,1
4.3,3.0,1.1,0.1,1
5.8,4.0,1.2,0.2,1
5.7,4.4,1.5,0.4,1
5.4,3.9,1.3,0.4,1
5.1,3.5,1.4,0.3,1
5.7,3.8,1.7,0.3,1
5.1,3.8,1.5,0.3,1
5.4,3.4,1.7,0.2,1
5.1,3.7,1.5,0.4,1
4.6,3.6,1.0,0.2,1
5.1,3.3,1.7,0.5,1
4.8,3.4,1.9,0.2,1
5.0,3.0,1.6,0.2,1
5.0,3.4,1.6,0.4,1
5.2,3.5,1.5,0.2,1
5.2,3.4,1.4,0.2,1
4.7,3.2,1.6,0.2,1
4.8,3.1,1.6,0.2,1
5.4,3.4,1.5,0.4,1
5.2,4.1,1.5,0.1,1
5.5,4.2,1.4,0.2,1
4.9,3.1,1.5,0.2,1
5.0,3.2,1.2,0.2,1
5.5,3.5,1.3,0.2,1
4.9,3.6,1.4,0.1,1
4.4,3.0,1.3,0.2,1
5.1,3.4,1.5,0.2,1
5.0,3.5,1.3,0.3,1
4.5,2.3,1.3,0.3,1
4.4,3.2,1.3,0.2,1
5.0,3.5,1.6,0.6,1
5.1,3.8,1.9,0.4,1
4.8,3.0,1.4,0.3,1
5.1,3.8,1.6,0.2,1
4.6,3.2,1.4,0.2,1
5.3,3.7,1.5,0.2,1
5.0,3.3,1.4,0.2,1
7.0,3.2,4.7,1.4,2
6.4,3.2,4.5,1.5,2
6.9,3.1,4.9,1.5,2
5.5,2.3,4.0,1.3,2
6.5,2.8,4.6,1.5,2
5.7,2.8,4.5,1.3,2
6.3,3.3,4.7,1.6,2
4.9,2.4,3.3,1.0,2
6.6,2.9,4.6,1.3,2
5.2,2.7,3.9,1.4,2
5.0,2.0,3.5,1.0,2
5.9,3.0,4.2,1.5,2
6.0,2.2,4.0,1.0,2
6.1,2.9,4.7,1.4,2
5.6,2.9,3.6,1.3,2
6.7,3.1,4.4,1.4,2
5.6,3.0,4.5,1.5,2
5.8,2.7,4.1,1.0,2
6.2,2.2,4.5,1.5,2
5.6,2.5,3.9,1.1,2
5.9,3.2,4.8,1.8,2
6.1,2.8,4.0,1.3,2
6.3,2.5,4.9,1.5,2
6.1,2.8,4.7,1.2,2
6.4,2.9,4.3,1.3,2
6.6,3.0,4.4,1.4,2
6.8,2.8,4.8,1.4,2
6.7,3.0,5.0,1.7,2
6.0,2.9,4.5,1.5,2
5.7,2.6,3.5,1.0,2
5.5,2.4,3.8,1.1,2
5.5,2.4,3.7,1.0,2
5.8,2.7,3.9,1.2,2
6.0,2.7,5.1,1.6,2
5.4,3.0,4.5,1.5,2
6.0,3.4,4.5,1.6,2
6.7,3.1,4.7,1.5,2
6.3,2.3,4.4,1.3,2
5.6,3.0,4.1,1.3,2
5.5,2.5,4.0,1.3,2
5.5,2.6,4.4,1.2,2
6.1,3.0,4.6,1.4,2
5.8,2.6,4.0,1.2,2
5.0,2.3,3.3,1.0,2
5.6,2.7,4.2,1.3,2
5.7,3.0,4.2,1.2,2
5.7,2.9,4.2,1.3,2
6.2,2.9,4.3,1.3,2
5.1,2.5,3.0,1.1,2
5.7,2.8,4.1,1.3,2
6.3,3.3,6.0,2.5,3
5.8,2.7,5.1,1.9,3
7.1,3.0,5.9,2.1,3
6.3,2.9,5.6,1.8,3
6.5,3.0,5.8,2.2,3
7.6,3.0,6.6,2.1,3
4.9,2.5,4.5,1.7,3
7.3,2.9,6.3,1.8,3
6.7,2.5,5.8,1.8,3
7.2,3.6,6.1,2.5,3
6.5,3.2,5.1,2.0,3
6.4,2.7,5.3,1.9,3
6.8,3.0,5.5,2.1,3
5.7,2.5,5.0,2.0,3
5.8,2.8,5.1,2.4,3
6.4,3.2,5.3,2.3,3
6.5,3.0,5.5,1.8,3
7.7,3.8,6.7,2.2,3
7.7,2.6,6.9,2.3,3
6.0,2.2,5.0,1.5,3
6.9,3.2,5.7,2.3,3
5.6,2.8,4.9,2.0,3
7.7,2.8,6.7,2.0,3
6.3,2.7,4.9,1.8,3
6.7,3.3,5.7,2.1,3
7.2,3.2,6.0,1.8,3
6.2,2.8,4.8,1.8,3
6.1,3.0,4.9,1.8,3
6.4,2.8,5.6,2.1,3
7.2,3.0,5.8,1.6,3
7.4,2.8,6.1,1.9,3
7.9,3.8,6.4,2.0,3
6.4,2.8,5.6,2.2,3
6.3,2.8,5.1,1.5,3
6.1,2.6,5.6,1.4,3
7.7,3.0,6.1,2.3,3
6.3,3.4,5.6,2.4,3
6.4,3.1,5.5,1.8,3
6.0,3.0,4.8,1.8,3
6.9,3.1,5.4,2.1,3
6.7,3.1,5.6,2.4,3
6.9,3.1,5.1,2.3,3
5.8,2.7,5.1,1.9,3
6.8,3.2,5.9,2.3,3
6.7,3.3,5.7,2.5,3
6.7,3.0,5.2,2.3,3
6.3,2.5,5.0,1.9,3
6.5,3.0,5.2,2.0,3
6.2,3.4,5.4,2.3,3
5.9,3.0,5.1,1.8,3];
% 每類的前40個樣本的均值用於代表該類,後10個作為獨立的測試樣本
m1 = mean(yuanwei_data(1:40, 1:4) ); %第1類的前40個樣本的平均向量
m2 = mean( yuanwei_data(51:90, 1:4) ); %第2類的前40個樣本的平均向量
m3 = mean( yuanwei_data(101:140, 1:4) ); %第3類的前40個樣本的平均向量
% 測試樣本集
Test = [yuanwei_data(41:50, :); yuanwei_data(91:100, :); yuanwei_data(141:150, :)];
% 測試樣本集對應的類別標籤
classLabel =Test (:,5 )';
% 利用最近鄰分類器分類測試樣本
class = zeros(1, 30); %類標籤
for ii = 1:size(Test, 1)%待分類樣本有3個鄰居,且鄰居分別為不同的類的樣本均值
   d(1) = norm(Test(ii, 1:4) - m1); %與第1類的距離
   d(2) = norm(Test(ii, 1:4) - m2); %與第2類的距離
   d(3) = norm(Test(ii, 1:4) - m3); %與第3類的距離  
   [minVal class(ii)] = min(d); %計算最小距離並將距離樣本最短的類賦給類標籤陣列 class
end
% 測試最近鄰分類器的識別率
nErr = sum(class ~= classLabel);
rate = 1 - nErr / length(class);
classLabel
class
strOut = ['識別率為', num2str(rate*100), '%']

程式執行結果如下圖


從執行結果看,只有第3類的第3個測試樣本分類錯誤,被分為第2類。