機器學習:KNN演算法(MATLAB實現)
阿新 • • 發佈:2019-02-12
K-近鄰演算法的思想如下:首先,計算新樣本與訓練樣本之間的距離,找到距離最近的K 個鄰居;然後,根據這些鄰居所屬的類別來判定新樣本的類別,如果它們都屬於同一個類別,那麼新樣本也屬於這個類;否則,對每個後選類別進行評分,按照某種規則確定新樣本的類別。(統計出現的頻率)
該演算法比較適用於樣本容量比較大的類域的自動分類,而那些樣本容量較小的類域採用這種演算法比較容易產生誤分當K值較小時可能產生過擬合,因為訓練誤差很小,但是測試誤差可能很大;相反,當K值較大時可能產生欠擬合。
演算法虛擬碼
對未知類別屬性的資料集中的每個點依次執行以下操作:
(1) 計算已知類別的資料集中的點與當前點之間的距離;
(2) 按照距離遞增次序排序;
(3) 選取與當前點距離最小的K個點;
(4) 確定前K個點所在類別的出現頻率;
(5) 返回前K個點出現頻率最高的類別作為當前點的預測分類。
- %
- %手寫數字識別系統的測試程式碼
- %
- function handWritingTest()
- tic; %開始計時
- K = 3; % 這裡可以調整k值
- trainLabels = [];
- direct = mfilename('fullpath');%
- traindirect = strrep(direct,'handWritingTest','trainingDigits'); %trainingDigits
- %獲得路徑
- traindirfile = dir(fullfile(traindirect,'*.txt'));%提取字尾名.txt
- traindircell = struct2cell(traindirfile)';
- trainfilenames = traindircell(:,1);
- trainfileNums = length(trainfilenames);
- trainMat = zeros(trainfileNums,1024);
- for i = 1:trainfileNums
- fileNameStr = trainfilenames(i);
- str = deblank(fileNameStr);
- s = regexp(str,'\.','split'); %
- fileStr = s{1}(1);
- classNumStr = regexp(fileStr,'\_','split');
- trainLabels(i)=str2num(char(classNumStr{1}(1))); %得到類別 0 - 9
- filePath = strcat(traindirect,'\',fileNameStr); %檔案路徑
- trainMat(i,:) = img2vector(filePath);%處理檔案 獲得向量
- end
- %測試樣本
- direct = mfilename('fullpath');
- testdirect = strrep(direct,'handWritingTest','testDigits');%testDigits
- testdirfile = dir(fullfile(testdirect,'*.txt'));
- testdircell = struct2cell(testdirfile)';
- testfilenames = testdircell(:,1);
- testfileNums = length(testfilenames);
- errorcount = 0;
- for j = 1:testfileNums
- fileNameStr = testfilenames(j);
- str = deblank(fileNameStr);
- s = regexp(str,'\.','split');
- fileStr = s{1}(1);
- classNumStr = regexp(fileStr,'\_','split');
- testLabel = str2num(char(classNumStr{1}(1))); %得到類別 0 - 9
- filePath = strcat(testdirect,'\',fileNameStr);
- testVector = img2vector(filePath);
- classifyRet = classify(testVector,trainMat,trainLabels,K);
- if(classifyRet ~= testLabel)
- errorcount = errorcount + 1;
- fprintf('test result: %d, real result: %d , here error!!! \n',classifyRet,testLabel);
- else
- fprintf('test result: %d, real result: %d \n',classifyRet,testLabel);
- end
- end
- lastTime = num2str(toc);
- fprintf('\n the sum numbers of errors : %d ',errorcount);
- fprintf('\n the total error rate : %f ' ,(errorcount / testfileNums));
- fprintf('\n total time : %f',lastTime);
- end
- %
- %KNN演算法 classify(test,dataSet,labels,k)
- %四個引數:test用於分類的輸入向量;輸入的訓練樣本集為dataSet;
- %標籤向量為labels; k 表示用於選擇最近鄰居的數目;
- %
- function maxClass = classify(test,dataSet,labels,k)
- [dataRow,dataCol] = size(dataSet);%dataRow:樣本個數;dataCol:特徵
- %求距離 test 與樣本資料之間的距離 這裡為歐式距離
- diffMat = dataSet;
- for i = 1:dataRow
- diffMat(i,:) = diffMat(i,:) - test;
- end
- sqdiffMat = diffMat.^2;
- sqDistances = sum(sqdiffMat,2).^(0.5);
- [p,q] = sort(sqDistances); %p代表要排序的數,q代表要排序的數原來對應的索引
- %通過k 來求最鄰居的前k 個數據,然後找的在這些資料中類別最多的
- classCount=zeros(10,1);
- class = [];
- for j = 1:k
- tempLabel = labels(q(j));
- class(j) = tempLabel;%沒用到
- classCount(tempLabel+1) = classCount(tempLabel+1)+1;
- end
- [r,s] = max(classCount);
- maxClass = s - 1; %返回 相似個數最多的 那個類
- end
- %
- %將32*32的二進位制圖形矩陣轉換為1*1024的向量
- %
- function retVector = img2vector(fileName)
- fileName = char(fileName);
- tempVector = [];
- % 讀檔案
- fileData = textread(fileName,'%s');
- fileData = char(fileData);%讀取檔案,並將檔案轉換矩陣的格式
- temp = fileData(:)';
- for i = 1 : length(temp)
- tempVector(i) = str2num(temp(i));
- end
- retVector = tempVector;
- end