1. 程式人生 > >K-means演算法和KNN演算法

K-means演算法和KNN演算法

github: 智慧演算法的課件和參考資料以及實驗程式碼

 

K-means是最為常用的聚類演算法,該演算法是將相似的樣本歸置在一起的一種無監督演算法。採用距離作為相似性的評價指標,即認為兩個物件的距離越近,其相似度就越大。

演算法主要步驟可描述如下:

 

  1. 隨機產生K個初始聚類中心。
  2. 計算測試點到聚類中心的距離,選擇距離最近的聚類中心將測試點歸類。
  3. 更新每類的聚類中心。
  4. 重複步驟2、3迭代更新,直至聚類中心不再改變,或者新的聚類中心與前一步聚類中心的距離小於某個值

KNN演算法是最簡單的分類演算法,如果一個樣本在特徵空間中的k個最相似(即特徵空間中最鄰近)的樣本中的大多數屬於某一個類別,則該樣本也屬於這個類別。演算法主要步驟可描述如下:

  1、計算已知類別資料集中的點與當前點之間的距離;

2、按照距離遞增依次排序;

3、選取與當前點距離最小的k個點

4、確定k個點在所在類別的出現頻率

5、返回k個點出現頻率最高的類別作為當前點的預測分類

 

下面直接給出程式碼(matlab實現):

k-means演算法:

clear all;close all; clc;
% 第一組資料
mul=[0,0]; % 均值
S1=[.1 0;0 .1]; % 協方差
data1=mvnrnd(mul, S1, 100); % 產生高斯分佈資料
% 第二組資料
mu2=[1.25 1.25];
S2=[.1 0;0 .1];
data2=mvnrnd(mu2,S2,100);
% 第三組資料
mu3=[-1.25;1.25]
S3=[.1 0;0 .1]
data3=mvnrnd(mu3,S3,100)
% 顯示資料
plot(data1(:,1),data1(:, 2),'b+');
hold on;
plot(data2(:,1),data2(:,2),'r+');
plot(data3(:,1),data3(:,2),'g+');
grid on; % 在畫圖的時候新增網格
% 三類資料合成一個不帶標號的資料類
data=[data1;data2;data3];
N=3; % 設定聚類數目k的值
[m,n]=size(data); % 300x2矩陣
pattern=zeros(m,n+1);
center=zeros(N,n); % 初始化聚類中心 3x2
pattern(:,1:n)=data(:,:);
for x=1:N
	center(x,:)=data(randi(300,1),:); % 第一次隨機產生聚類中心
end
while 1
	distance=zeros(1,N);
	num=zeros(1,N);
	new_center=zeros(N,n);
	for x=1:m
		for y=1:N 
			% 這裡使用的是歐氏距離
			distance(y)=norm(data(x,:)-center(y,:)); % 計算每個樣本到每個類中心的距離 
		end
		% min函式有三種呼叫形式 
		% min(A): 返回一個行向量,是每列最小值
		% [Y, U]=min(A): 返回行向量Y和U, Y向量記錄A的每列的最小值,U向量記錄每列最小值的行號
		% min(A, dim): dim取1或2.dim取1時,該函式與max(A)完全相同;dim為2時,返回列向量,其第
		% i個元素是A矩陣的第i行上的最小值
		[~,temp]=min(distance);
		pattern(x,n+1)=temp;
	end
	k=0;
	for y=1:N
		% 遍歷所有樣本,找到屬於第y類的樣本,並重新計算簇中心
		for x=1:m
			if pattern(x,n+1)==y
				new_center(y,:)=new_center(y,:)+pattern(x,1:n)
				num(y)=num(y)+1 % 計算第y類的所屬樣本個數,便於求均值
			end
		end
		new_center(y,:)=new_center(y,:)/num(y);
		if norm(new_center(y,:)-center(y,:))<0.1
			k=k+1
		end
	end
	if k==N
		break;
	else
		center=new_center
	end
end
[m,n]=size(pattern) % m=300,n=3

% 顯示聚類後的資料
figure;
hold on;
for i=1:m
	% 屬於不同類別的樣本畫成不一樣的r、b、g顏色的*狀,最終的簇中心為黑色圓圈
	if pattern(i,n)==1
		plot(pattern(i,1),pattern(i,2),'r*');
		plot(center(1,1),center(1,2),'ko');
	elseif pattern(i,n)==2
		plot(pattern(i,1),pattern(i,2),'g*');
		plot(center(2,1),center(2,2),'ko');
	elseif pattern(i,n)==3
		plot(pattern(i,1),pattern(i,2),'b*');
		plot(center(3,1),center(3,2),'ko');
	end
end 
grid on;

實驗結果:

原始資料raw data                                                           由於起始的簇中心選取很重要,因此效能也會變化,下面是不好的結果

良好的聚類結果:

 

下面給出了簡單的分類演算法KNN(用於分類的資料github下載):

clear all;close all;clc;
load data; % 讀取資料
Data = data;
% ?100個樣本進行歸?化處? min-max標準化方法[0,1]區間
% 缺點是當有新的資料加入時,max和min可能發生變化,需要重新定?
for i=1:100 
	for j=1:3
		Data(i,j)=(data(i,j)-min(data(:,j)))/(max(data(:,j))-min(data(:,j)));
	end
end
D1=Data(1:80,:);
D2=Data(81:100,:);
k=20; % 訓練集是80個樣本,測試集是20個樣?
for i=1:20
	temp=D2(i,1:3)
	for j=1:80 % 計算每個測試樣本到訓練樣本的距離向量
		distance(j)=norm(temp-D1(j,1:3));
	end
	[distance1,index]=sort(distance); %升序排列
	In=index(1:k); % 統計周圍20個訓練樣本的類別情況
	l1=length(find(D1(In,4)==1));
	l2=length(find(D1(In,4)==2));
	l3=length(find(D1(In,4)==3));
	[maxl,class(i)]=max([l1,l2,l3]); % class(i)是每個樣本所屬類別的20行向?
end
class
ratio=length(find((class'-D2(:,4))==0))/20 % 統計正確率是%90

可以看到分類效果不是很好,後面我們會利用PCA和異常資料監測來提高分類的效能。這裡只是介紹最簡單的演算法

 

總結:

K-means缺點:

(1) K是事先給定的,K值的選定是非常難以估計的

(2) 對異常資料很敏感。在計算質心的過程中,如果某個資料很異常,在計算均值的

    時候,會對結果影響非常大

 

KNN優缺點:

優點:

    精度高,對異常值不敏感,無資料輸入假定

缺點:

    計算量大,樣本不平衡時分類結果誤差大

改進: 事先對樣本進行剪輯,除去對分類作用不大的樣本權值的方法(和該樣本距離小的鄰域權值大)