從零開始 TensorFlow NN(最近鄰分類,Nearest Neighbor)
阿新 • 發佈:2018-12-24
使用 L1 距離,在訓練集中找出與每筆測試資料距離最近的樣本,再比較它們的標籤是否相同。
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# Load MNIST from /tmp/data/ with one-hot encoded labels.
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

# Sample sizes: how many training examples form the reference set that is
# searched for neighbours, and how many test examples are classified.
n_reference, n_queries = 5000, 200

# Reference set (images + one-hot labels) drawn from the training split.
Xtrain, Ytrain = mnist.train.next_batch(n_reference)
# Query set drawn from the test split; each row is classified on its own.
Xtest, Ytest = mnist.test.next_batch(n_queries)
# Graph inputs: the whole reference (training) set fed at once, and a single
# flattened 784-value test image fed per query.
xtrain = tf.placeholder(tf.float32, [None, 784])
xtest = tf.placeholder(tf.float32, [784])

# L1 (Manhattan) distance from the test image to every training image.
# xtest, shape [784], broadcasts against xtrain, shape [N, 784], so the
# element-wise difference is [N, 784]. Summing over axis 1 (the pixel axis)
# gives one distance per training image, shape [N]. The tensor is rank 2,
# so axis 1 is the last valid reduction axis.
distance = tf.reduce_sum(tf.abs(tf.subtract(xtrain, xtest)), axis=1)

# Index of the training image with the smallest L1 distance (the nearest
# neighbour).
predict = tf.argmin(distance, 0)
# Fraction of test images whose nearest neighbour has the same label.
accuracy = 0.
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    correct = 0
    for i in range(len(Xtest)):
        # Feed the full reference set plus one test image; get the index of
        # the closest training image.
        nn_index = sess.run(predict, feed_dict={xtrain: Xtrain, xtest: Xtest[i, :]})
        # Labels are one-hot, so argmax recovers the class id.
        predicted = np.argmax(Ytrain[nn_index])
        actual = np.argmax(Ytest[i])
        print('第', i, '次預測', predicted, '真正的類別:', actual)
        if predicted == actual:
            correct += 1
    # Divide once at the end instead of summing 1/len repeatedly (avoids
    # float drift). float() keeps this correct on Python 2, which this file
    # supports via __future__ and which has no true division here.
    # Guarded so an empty test set reports 0.0 instead of dividing by zero.
    if len(Xtest):
        accuracy = float(correct) / len(Xtest)
    print('accuracy:', accuracy)