1. 程式人生 > 其它 >K近鄰演算法(KNN)

K近鄰演算法(KNN)

1. k近鄰演算法(K-Nearest Neighbor,KNN)

  K最近鄰(k-Nearest Neighbor,KNN)分類演算法,是一個理論上比較成熟的方法,也是最簡單的機器學習演算法之一。該方法的思路是:在特徵空間中,如果一個樣本附近的k個最近(即特徵空間中最鄰近)樣本的大多數屬於某一個類別,則該樣本也屬於這個類別。如下圖所示:

2. 距離函式的定義

  在多維空間中,KNN使用的是歐氏距離度量周圍樣本距離預測樣本的距離。公式如下:

3. KNN實現鳶尾花(Iris)分類

  先看一下,鳶尾花長這樣子。

  鳶尾花資料集記錄了三類花以及它們的四種屬性。(四種屬性:花萼長度,花萼寬度,花瓣長度,花瓣寬度;3種標籤:Setosa,versicolor,virginica)

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

iris_dataset = load_iris()  # 載入資料集
feature = iris_dataset['data']
target = iris_dataset['target']
x_train, x_test, y_train, y_test = train_test_split(feature, target, test_size=0.2, random_state=0)
print(x_train) print(y_train)#標籤0,1,2表示三種不同的花卉(山鳶尾、變色鳶尾、維吉尼亞鳶尾) knn = KNeighborsClassifier(n_neighbors=3) knn = knn.fit(x_train, y_train) y_pred = knn.predict(x_test) print('模型的分類結果:', y_pred) print('真實的分類結果:', y_test) predict_result=knn.predict([[6.1, 3.1, 4.7, 2.1]]) print(predict_result)#結果是2,維吉尼亞鳶尾
4.KNN實現手寫數字識別分類

 手寫數字中用到的資料集為minist手寫數字資料集,每張圖片大小為28*28灰度圖片。訓練集為6萬張,測試集為1萬張,類別為0~9。

  一共包含四個檔案:

  train-images-idx3-ubyte.gz:訓練集影象(9912422 位元組)55000張訓練集 + 5000張驗證集;

  train-labels-idx1-ubyte.gz:訓練集標籤(28881 位元組)訓練集對應的標籤;

  t10k-images-idx3-ubyte.gz:測試集影象(1648877 位元組)10000張測試集;

  t10k-labels-idx1-ubyte.gz:測試集標籤(4542 位元組)測試集對應的標籤;

  比如讀取其中一個數據,轉換為向量形式,這裡我把向量放在了txt中,如圖所示,很顯然代表數字0.

  由於是灰度圖(一個通道),數字範圍為0~255,進一步,可以講向量表示的影象進行閾值二值化,變為0與1,如下圖:

  程式碼實現:

import gzip
import numpy as np
import matplotlib.pyplot as plt
import cv2

def readImage(path,num_images=1):
    f=gzip.open(path,'r')
    image_size = 28
    f.read(16)#前16個位元組儲存的是magic number,number of images,number of rows ,number of columns
    buf=f.read(image_size * image_size * num_images)
    data=np.frombuffer(buf,dtype=np.uint8).astype(np.float32)#ndarray型別
    print("data",type(data))
    data=data.reshape(num_images,image_size,image_size,1)
    image = np.asarray(data).squeeze()
    return image


def readLabel(path,num_images=1):
    f = gzip.open(path,'r')
    f.read(8)
    labels=[]
    for i in range(0,num_images):
        buf = f.read(1)
        label = np.frombuffer(buf, dtype=np.uint8).astype(np.int64)
        labels.append(label)
    return  labels

def calEuclidean(x, y): #計算兩向量的歐氏距離
    dist = np.sqrt(np.sum(np.square(x-y)))   # 注意:np.array 型別的資料可以直接進行向量、矩陣加減運算。np.square 是對每個元素求平均
    return dist

def knn(test,train,trainLabel,k):   #給定任意一個圖片的向量採用k近鄰求得所屬類別,k表示距離最近的K的元素
    distAll=[]
    for i in range(0,len(train)):
        dist=calEuclidean(test,train[i])
        distAll.append(dist)
    #對距離排序
    # print(distAll)
    distAllTuple = sorted(zip(distAll, range(len(distAll))))
    distAllTuple.sort(key=lambda t: t[0])  #
    sort_distAll_position = [x[1] for x in distAllTuple]  # 得到排序後原來的下標,方便後面找label
    # print(sort_distAll_position)
    #計算前k個哪個類別出現的最多
    labelNum={}
    for i in range(0,k):
        if str(trainLabel[sort_distAll_position[i]][0]) not in labelNum:
            labelNum[str(trainLabel[sort_distAll_position[i]][0])]=1
        else:
            labelNum[str(trainLabel[sort_distAll_position[i]][0])]+=1
    print("在最近的k個元素中,每個類別出現的頻率為",labelNum)
    return max(labelNum,key=labelNum.get)#返回數字最大所對應的類別

def accuracy(predictResult,testlabels):
    all = len(predictResult)
    TP = 0
    for i in range(0,len(predictResult)):
        if str(testlabels[i][0])== predictResult[i]:
            TP+=1
    return TP/all


if __name__=="__main__":
    path1='MNIST_data/train-images-idx3-ubyte.gz'
    # num_images=60000
    num_images = 1000
    imageArray=readImage(path1,num_images)#得到的是num個影象的向量
    # retval, dst = cv2.threshold(imageArray[1], 50, 1, cv2.THRESH_BINARY)#可以對影象閾值處理,處理第2個圖片為二值影象
    retval, dst = cv2.threshold(imageArray, 50, 1, cv2.THRESH_BINARY)  # 可以對所有影象閾值處理,處理為二值影象
    # plt.imshow(imageArray,cmap='Greys_r')
    # plt.imshow(dst, cmap='Greys_r')
    # print(dst)
    # plt.show()
    labels=readLabel('MNIST_data/train-labels-idx1-ubyte.gz',1000)
    # print(labels[1])#類別為0
    imageArraytest = readImage('MNIST_data/t10k-images-idx3-ubyte.gz', num_images=100)#讀取100個測試圖片
    testlabels=readLabel('MNIST_data/t10k-labels-idx1-ubyte.gz',100)#讀取100個測試圖片的標籤
    retval, dsttest = cv2.threshold(imageArraytest, 50, 1, cv2.THRESH_BINARY)  # 可以對影象閾值處理,處理為二值影象
    # plt.imshow(imageArray,cmap='Greys_r')
    plt.imshow(dsttest[0], cmap='Greys_r')#顯示7的影象
    print(dsttest[0])
    plt.show()
    #隨便拿第一張圖片測試
    labelResult=knn(dsttest[0], dst, labels, 100)
    print("預測的結果為",labelResult)#輸出為7
    print("實際結果為",testlabels[0])
    #以下測試100個圖片,並計算準確率
    predictResult=[]
    for i in range(0,100):
        labelResult = knn(dsttest[i], dst, labels, 100)
        predictResult.append(labelResult)
    print(predictResult)
    Pre_accuracy=accuracy(predictResult, testlabels)
    print("準確率",Pre_accuracy)

  執行結果:

  最後可以根據取不同的K值,找到最優K,另外使用交叉驗證法能夠找到更好的K值。

5.  KNN分類演算法的優點與缺點

  優點:

    1.簡單,易於理解,易於實現,無需估計引數,無需訓練;

    2.適合對稀有事件進行分類;

    3.特別適合於多分類問題(multi-modal,物件具備多個類別標籤), KNN比SVM的表現要好。

  缺點:

    1.樣本不平衡時,會使結果出現誤差

    2.計算量龐大

    3.因為不須要訓練,因此演算法的可控性比較差

  附上資料與程式碼:  連結:https://pan.baidu.com/s/1_3tUaXqSsq3uKlx1o-FbPA     提取碼:mqff

  若存在不足或錯誤之處,歡迎指正與評論!

參考資料

             https://blog.csdn.net/qq_45603919/article/details/120478822

             https://blog.csdn.net/qq_42302831/article/details/102553007

             https://blog.csdn.net/asialee_bird/article/details/81051281

             http://www.javashuo.com/article/p-owdkjgnl-mc.html

             https://blog.csdn.net/asialee_bird/article/details/81051281

             https://www.cnblogs.com/zhangzhenw/p/14583195.html