1. 程式人生 > >tenserflow例項之最近鄰演算法

tenserflow例項之最近鄰演算法

一、概念&意義

鄰近演算法,或者說K最近鄰(kNN,k-NearestNeighbor)分類演算法是資料探勘分類技術中最簡單的方法之一。所謂K最近鄰,就是k個最近的鄰居的意思,說的是每個樣本都可以用它最接近的k個鄰居來代表。kNN演算法的核心思想是如果一個樣本在特徵空間中的k個最相鄰的樣本中的大多數屬於某一個類別,則該樣本也屬於這個類別,並具有這個類別上樣本的特性。該方法在確定分類決策上只依據最鄰近的一個或者幾個樣本的類別來決定待分樣本所屬的類別。 kNN方法在類別決策時,只與極少量的相鄰樣本有關。由於kNN方法主要靠周圍有限的鄰近的樣本,而不是靠判別類域的方法來確定所屬類別的,因此對於類域的交叉或重疊較多的待分樣本集來說,kNN方法較其他方法更為適合。

二、演算法流程

1.準備資料,對資料進行預處理

2.選用合適的資料結構儲存訓練資料和測試元組

3.設定引數,如k

4.維護一個大小為k的的按距離由大到小的優先順序佇列,用於儲存最近鄰訓練元組。隨機從訓練元組中選取k個元組作為初始的最近鄰元組,分別計算測試元組到這k個元組的距離,將訓練元組標號和距離存入優先順序佇列

5.遍歷訓練元組集,計算當前訓練元組與測試元組的距離,將所得距離L 與優先順序佇列中的最大距離Lmax

6.進行比較。若L>=Lmax,則捨棄該元組,遍歷下一個元組。若L < Lmax,刪除優先順序佇列中最大距離的元組,將當前訓練元組存入優先順序佇列。

7.遍歷完畢,計算優先順序佇列中k 個元組的多數類,並將其作為測試元組的類別。

8.測試元組集測試完畢後計算誤差率,繼續設定不同的k值重新進行訓練,最後取誤差率最小的k 值。

三、距離度量


a)當p=1時,為曼哈頓距離;

b)當p=2時,為歐式距離;

四:利用MNIST data做的一個KNNf分類

import numpy as np
import tensorflow as tf

# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

#選取訓練集、測試集數目,這裡Ytr是一個10維的ndarray,即0~9


Xtr, Ytr = mnist.train.next_batch(5000) #5000 for training (nn candidates)
Xte, Yte = mnist.test.next_batch(200) #200 for testing

#定義變數大小
xtr = tf.placeholder("float", [None, 784])
xte = tf.placeholder("float", [784])


#計算測試資料與訓練資料L1範數大小(1表示從橫軸進行降維) 
distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.negative(xte))), reduction_indices=1)

#求得distance最小的下標(0表示從豎軸計算)  
pred = tf.arg_min(distance, 0)
accuracy = 0.

# Initializing the variables
init = tf.global_variables_initializer()

# Launch the graph
with tf.Session() as sess:
    sess.run(init)
    # loop over test data
    for i in range(len(Xte)):
     #近鄰演算法:測試集與訓練集對比,返回誤差最小的下標  
        nn_index = sess.run(pred, feed_dict={xtr: Xtr, xte: Xte[i, :]})
        # Get nearest neighbor class label and compare it to its true label
        print("Test", i, "Prediction:", np.argmax(Ytr[nn_index]), \
            "True Class:", np.argmax(Yte[i]))
        # Calculate accuracy
        if np.argmax(Ytr[nn_index]) == np.argmax(Yte[i]):
            accuracy += 1./len(Xte)
    print("Done!")
    print("Accuracy:", accuracy)