最近鄰分類器及MATLAB實現
應用背景:在前面一些影象處理相關的文章中,已經說到影象的特徵提取,在選擇好一些主要特徵之後,那麼我們用這些特徵做什麼用呢,我們的主要目的是利用這些特徵對影象進行分類。接下來的問題是怎麼分類,這裡介紹最近鄰分類,它是一種最簡單的分類方法。
基本思想:最近鄰分類,顧名思義,距鄰居最近,則與鄰居同類。也就是說,一個待分類的單個樣本A,放入已分好類的多個樣本群Q中,從Q中選擇k個A的鄰居,通過計算A與鄰居之間的某種關係後得出A與這k個鄰居最相似,那麼就把A分為這k個鄰居中出現次數最多的類,因此最近鄰分類也稱k最近鄰分類(k nearest neighbor, KNN)。這種分類方法基本分類3步:1、找待分類樣本與已分類樣本之間的關係,這裡指計算它們之間的距離;2、找距離最近的k
數學原理:在特徵空間中,把每個類的所有樣本的平均值表示為該類,則第i類樣本的均值為:
(1)
其中,Ni為第i類樣本的樣本數目,Wi為第i類樣本集合,W為總類別數目。
樣本之間的距離取歐氏距離,當對一個未知模式 x 進行分類時,需要分別計算 x 與各個類的歐氏距離,如下式所示
其中,|| x-mi ||=((x-mi)^T(x-mi))^1/2,表示歐幾里得範數,即向量的模。
MATLAB實現:
演算法虛擬碼如下
initialize Di(x)=max
for i=1:w
Di(x)=|| x-mi ||
end
find k neighbors of x
find max_Di(x) in k neighbors of x
for i=1:w
if Di(x)<max_Di(x)
i is the neighbors of x
end
count the number of classes in the k nearest neighbors
MATLAB程式碼如下
二維平面兩類分類問題:
clear;
k=9; % 最近鄰居的數目
num_po=100;
x11=rand(num_po,1);
x12=rand(num_po,1);
x1=[x11 x12];
y1=ones(num_po,1);
num_ne=100;
x21=rand(num_ne,1)+1;
x22=rand(num_ne,1);
x2=[x21 x22];
y2=-1*ones(num_ne,1);
x=[x1;x2];
y=[y1;y2];
ClassLabel=unique(y);
num_t=20; % 測試樣本
test1=rand(num_t,1)+0.5;
test2=rand(num_t,1);
test=[test1 test2];
for num=1:num_t
for i=1:(num_po+num_ne)
dis(i)=norm(test(num,:)-x(i,:));
end
[dis_s,index]=sort(dis);
cnt=zeros(1,2);
for j=1:k
ind=find(ClassLabel==y(index(j)));
cnt(ind)=cnt(ind)+1;
end
[m,ind]=max(cnt);
y_test(num)=ClassLabel(ind);% 測試樣本的標記
end
for i=1:(num_po+num_ne) % 訓練樣本圖示
if y(i)>0
plot(x(i,1),x(i,2),'r+');
hold on
else
plot(x(i,1),x(i,2),'b.');
hold on
end
end
for i=1:num_t % 測試樣樣本圖示
if y_test(i)>0
plot(test(i,1),test(i,2),'g+');
title('K-最近鄰分類器');
hold on
else
plot(test(i,1),test(i,2),'y.');
hold on
end
end
執行結果如下
對於二維平面兩類分類問題,每個訓練樣本都有一個距離必須度量,資料計算量過大,耗費大量時間。
下面的例子使用樣本資料的均值來代表類
%鳶尾花屬植物資料集
%花萼長度 花萼寬度 花瓣長度 花瓣寬度 屬種
yuanwei_data=[5.1,3.5,1.4,0.2,1
4.9,3.0,1.4,0.2,1
4.7,3.2,1.3,0.2,1
4.6,3.1,1.5,0.2,1
5.0,3.6,1.4,0.2,1
5.4,3.9,1.7,0.4,1
4.6,3.4,1.4,0.3,1
5.0,3.4,1.5,0.2,1
4.4,2.9,1.4,0.2,1
4.9,3.1,1.5,0.1,1
5.4,3.7,1.5,0.2,1
4.8,3.4,1.6,0.2,1
4.8,3.0,1.4,0.1,1
4.3,3.0,1.1,0.1,1
5.8,4.0,1.2,0.2,1
5.7,4.4,1.5,0.4,1
5.4,3.9,1.3,0.4,1
5.1,3.5,1.4,0.3,1
5.7,3.8,1.7,0.3,1
5.1,3.8,1.5,0.3,1
5.4,3.4,1.7,0.2,1
5.1,3.7,1.5,0.4,1
4.6,3.6,1.0,0.2,1
5.1,3.3,1.7,0.5,1
4.8,3.4,1.9,0.2,1
5.0,3.0,1.6,0.2,1
5.0,3.4,1.6,0.4,1
5.2,3.5,1.5,0.2,1
5.2,3.4,1.4,0.2,1
4.7,3.2,1.6,0.2,1
4.8,3.1,1.6,0.2,1
5.4,3.4,1.5,0.4,1
5.2,4.1,1.5,0.1,1
5.5,4.2,1.4,0.2,1
4.9,3.1,1.5,0.2,1
5.0,3.2,1.2,0.2,1
5.5,3.5,1.3,0.2,1
4.9,3.6,1.4,0.1,1
4.4,3.0,1.3,0.2,1
5.1,3.4,1.5,0.2,1
5.0,3.5,1.3,0.3,1
4.5,2.3,1.3,0.3,1
4.4,3.2,1.3,0.2,1
5.0,3.5,1.6,0.6,1
5.1,3.8,1.9,0.4,1
4.8,3.0,1.4,0.3,1
5.1,3.8,1.6,0.2,1
4.6,3.2,1.4,0.2,1
5.3,3.7,1.5,0.2,1
5.0,3.3,1.4,0.2,1
7.0,3.2,4.7,1.4,2
6.4,3.2,4.5,1.5,2
6.9,3.1,4.9,1.5,2
5.5,2.3,4.0,1.3,2
6.5,2.8,4.6,1.5,2
5.7,2.8,4.5,1.3,2
6.3,3.3,4.7,1.6,2
4.9,2.4,3.3,1.0,2
6.6,2.9,4.6,1.3,2
5.2,2.7,3.9,1.4,2
5.0,2.0,3.5,1.0,2
5.9,3.0,4.2,1.5,2
6.0,2.2,4.0,1.0,2
6.1,2.9,4.7,1.4,2
5.6,2.9,3.6,1.3,2
6.7,3.1,4.4,1.4,2
5.6,3.0,4.5,1.5,2
5.8,2.7,4.1,1.0,2
6.2,2.2,4.5,1.5,2
5.6,2.5,3.9,1.1,2
5.9,3.2,4.8,1.8,2
6.1,2.8,4.0,1.3,2
6.3,2.5,4.9,1.5,2
6.1,2.8,4.7,1.2,2
6.4,2.9,4.3,1.3,2
6.6,3.0,4.4,1.4,2
6.8,2.8,4.8,1.4,2
6.7,3.0,5.0,1.7,2
6.0,2.9,4.5,1.5,2
5.7,2.6,3.5,1.0,2
5.5,2.4,3.8,1.1,2
5.5,2.4,3.7,1.0,2
5.8,2.7,3.9,1.2,2
6.0,2.7,5.1,1.6,2
5.4,3.0,4.5,1.5,2
6.0,3.4,4.5,1.6,2
6.7,3.1,4.7,1.5,2
6.3,2.3,4.4,1.3,2
5.6,3.0,4.1,1.3,2
5.5,2.5,4.0,1.3,2
5.5,2.6,4.4,1.2,2
6.1,3.0,4.6,1.4,2
5.8,2.6,4.0,1.2,2
5.0,2.3,3.3,1.0,2
5.6,2.7,4.2,1.3,2
5.7,3.0,4.2,1.2,2
5.7,2.9,4.2,1.3,2
6.2,2.9,4.3,1.3,2
5.1,2.5,3.0,1.1,2
5.7,2.8,4.1,1.3,2
6.3,3.3,6.0,2.5,3
5.8,2.7,5.1,1.9,3
7.1,3.0,5.9,2.1,3
6.3,2.9,5.6,1.8,3
6.5,3.0,5.8,2.2,3
7.6,3.0,6.6,2.1,3
4.9,2.5,4.5,1.7,3
7.3,2.9,6.3,1.8,3
6.7,2.5,5.8,1.8,3
7.2,3.6,6.1,2.5,3
6.5,3.2,5.1,2.0,3
6.4,2.7,5.3,1.9,3
6.8,3.0,5.5,2.1,3
5.7,2.5,5.0,2.0,3
5.8,2.8,5.1,2.4,3
6.4,3.2,5.3,2.3,3
6.5,3.0,5.5,1.8,3
7.7,3.8,6.7,2.2,3
7.7,2.6,6.9,2.3,3
6.0,2.2,5.0,1.5,3
6.9,3.2,5.7,2.3,3
5.6,2.8,4.9,2.0,3
7.7,2.8,6.7,2.0,3
6.3,2.7,4.9,1.8,3
6.7,3.3,5.7,2.1,3
7.2,3.2,6.0,1.8,3
6.2,2.8,4.8,1.8,3
6.1,3.0,4.9,1.8,3
6.4,2.8,5.6,2.1,3
7.2,3.0,5.8,1.6,3
7.4,2.8,6.1,1.9,3
7.9,3.8,6.4,2.0,3
6.4,2.8,5.6,2.2,3
6.3,2.8,5.1,1.5,3
6.1,2.6,5.6,1.4,3
7.7,3.0,6.1,2.3,3
6.3,3.4,5.6,2.4,3
6.4,3.1,5.5,1.8,3
6.0,3.0,4.8,1.8,3
6.9,3.1,5.4,2.1,3
6.7,3.1,5.6,2.4,3
6.9,3.1,5.1,2.3,3
5.8,2.7,5.1,1.9,3
6.8,3.2,5.9,2.3,3
6.7,3.3,5.7,2.5,3
6.7,3.0,5.2,2.3,3
6.3,2.5,5.0,1.9,3
6.5,3.0,5.2,2.0,3
6.2,3.4,5.4,2.3,3
5.9,3.0,5.1,1.8,3];
% 每類的前40個樣本的均值用於代表該類,後10個作為獨立的測試樣本
m1 = mean(yuanwei_data(1:40, 1:4) ); %第1類的前40個樣本的平均向量
m2 = mean( yuanwei_data(51:90, 1:4) ); %第2類的前40個樣本的平均向量
m3 = mean( yuanwei_data(101:140, 1:4) ); %第3類的前40個樣本的平均向量
% 測試樣本集
Test = [yuanwei_data(41:50, :); yuanwei_data(91:100, :); yuanwei_data(141:150, :)];
% 測試樣本集對應的類別標籤
classLabel =Test (:,5 )';
% 利用最近鄰分類器分類測試樣本
class = zeros(1, 30); %類標籤
for ii = 1:size(Test, 1)%待分類樣本有3個鄰居,且鄰居分別為不同的類的樣本均值
d(1) = norm(Test(ii, 1:4) - m1); %與第1類的距離
d(2) = norm(Test(ii, 1:4) - m2); %與第2類的距離
d(3) = norm(Test(ii, 1:4) - m3); %與第3類的距離
[minVal class(ii)] = min(d); %計算最小距離並將距離樣本最短的類賦給類標籤陣列 class
end
% 測試最近鄰分類器的識別率
nErr = sum(class ~= classLabel);
rate = 1 - nErr / length(class);
classLabel
class
strOut = ['識別率為', num2str(rate*100), '%']
程式執行結果如下圖
從執行結果看,只有第3類的第3個測試樣本分類錯誤,被分為第2類。