1. 程式人生 > >Tensorflow擼程式碼之3knn演算法

Tensorflow擼程式碼之3knn演算法

具體knn演算法概念參考knn程式碼python實現
上面是參考《機器學習實戰》的程式碼,和knn的思想

# _*_ encoding=utf8 _*_

import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# 匯入手寫體識別的資料
mnist = input_data.read_data_sets("../data", one_hot=True)

# 訓練集和測試集
X_train, Y_train = mnist.train.next_batch(
5000) # 資料和labels X_test, Y_test = mnist.test.next_batch(100) # 定義輸入 x_train = tf.placeholder(tf.float32, shape=(None,784)) x_test = tf.placeholder(tf.float32, shape=(784)) # L1距離也就是城市街區距離 |x1-x2|+|y1-y2| distance = tf.reduce_sum(tf.abs(tf.add(x_train,tf.negative(x_test))),reduction_indices = 1) # 返回最近的座標,0縱軸 1橫軸
pred = tf.arg_min(distance, 0) accuracy = 0 # 初始化 init = tf.global_variables_initializer() with tf.Session() as sess: sess.run(init) for i in range(len(X_test)): # 獲取當前樣本的最近鄰索引,當前樣本和每一個訓練的樣本找一個最近的l1距離,得到這個最小距離的下標 nn_index = sess.run(pred, feed_dict={x_train:X_train, x_test:
X_test[i, :]}) # 由最鄰近索引找到label,然後最鄰近的label與真實標籤比較 np.argmax找最大的下標 # 由l1距離找到的最小值對應的座標,通過該最座標找到對應行label的最大值的下標,這個下標對應的就是數字的大小 print("預測次數", i, "預測標籤:", np.argmax(Y_train[nn_index]),"真實標籤:", np.argmax(Y_test[i])) # 計算準確率 if np.argmax(Y_train[nn_index]) == np.argmax(Y_test[i]): accuracy += 1 print("Accuracy:", float(accuracy)/len(X_test))