1. 程式人生 > >簡單易學的機器學習演算法——K-Means演算法

簡單易學的機器學習演算法——K-Means演算法

一、聚類演算法的簡介

    聚類演算法是一種典型的無監督學習演算法,主要用於將相似的樣本自動歸到一個類別中。聚類演算法與分類演算法最大的區別是:聚類演算法是無監督的學習演算法,而分類演算法屬於監督的學習演算法。

    在聚類演算法中根據樣本之間的相似性,將樣本劃分到不同的類別中,對於不同的相似度計算方法,會得到不同的聚類結果,常用的相似度計算方法有歐式距離法。

二、K-Means演算法的概述

   基本K-Means演算法的思想很簡單,事先確定常數K,常數K意味著最終的聚類類別數,首先隨機選定初始點為質心,並通過計算每一個樣本與質心之間的相似度(這裡為歐式距離),將樣本點歸到最相似的類中,接著,重新計算每個類的質心(即為類中心),重複這樣的過程,知道質心不再改變,最終就確定了每個樣本所屬的類別以及每個類的質心。由於每次都要計算所有的樣本與每一個質心之間的相似度,故在大規模的資料集上,

K-Means演算法的收斂速度比較慢。

三、K-Means演算法的流程

  • 初始化常數K,隨機選取初始點為質心
  • 重複計算一下過程,直到質心不再改變
    • 計算樣本與每個質心之間的相似度,將樣本歸類到最相似的類中
    • 重新計算質心
  • 輸出最終的質心以及每個類

四、K-Means演算法的實現

    對資料集進行測試
原始資料集MATLAB程式碼主程式
%% input the data
A = load('testSet.txt');

%% 計算質心
centroids = kMeans(A, 4);

隨機選取質心
%% 取得隨機中心
function [ centroids ] = randCent( dataSet, k )
    [m,n] = size(dataSet);%取得列數
    centroids = zeros(k, n);
    for j = 1:n
        minJ = min(dataSet(:,j));
        rangeJ = max(dataSet(:,j))-min(dataSet(:,j));
        centroids(:,j) = minJ+rand(k,1)*rangeJ;%產生區間上的隨機數
    end
end

計算相似性
function [ dist ] = distence( vecA, vecB )
    dist = (vecA-vecB)*(vecA-vecB)';%這裡取歐式距離的平方
end

kMeans的主程式
%% kMeans的核心程式,不斷迭代求解聚類中心
function [ centroids ] = kMeans( dataSet, k )
    [m,n] = size(dataSet);
    %初始化聚類中心
    centroids = randCent(dataSet, k);
    subCenter = zeros(m,2);%做一個m*2的矩陣,第一列儲存類別,第二列儲存距離
    change = 1;%判斷是否改變
    while change == 1
        change = 0;
        %對每一組資料計算距離
        for i = 1:m
            minDist = inf;
            minIndex = 0;
            for j = 1:k
                 dist= distence(dataSet(i,:), centroids(j,:));
                 if dist < minDist
                     minDist = dist;
                     minIndex = j;
                 end
            end
            if subCenter(i,1) ~= minIndex
                change = 1;
                subCenter(i,:)=[minIndex, minDist];
            end        
        end
        %對k類重新就算聚類中心
        
        for j = 1:k
            sum = zeros(1,n);
            r = 0;%數量
            for i = 1:m
                if subCenter(i,1) == j
                    sum = sum + dataSet(i,:);
                    r = r+1;
                end
            end
            centroids(j,:) = sum./r;
        end
    end
    
    %% 完成作圖
    hold on
    for i = 1:m
        switch subCenter(i,1)
            case 1
                plot(dataSet(i,1), dataSet(i,2), '.b');
            case 2
                plot(dataSet(i,1), dataSet(i,2), '.g');
            case 3
                plot(dataSet(i,1), dataSet(i,2), '.r');
            otherwise
                plot(dataSet(i,1), dataSet(i,2), '.c');
        end
    end
    plot(centroids(:,1),centroids(:,2),'+k');
end

最終的聚類結果