TensorFlow(八) TensorFlow圖像識別(KNN)
阿新 • • 發佈:2018-06-29
nump session trac inf dict sha ceil dom 數據
# k-nearest-neighbour (KNN) digit recognition on MNIST with TensorFlow 1.x.
#
# A random subset of the MNIST training split is used as the reference set,
# and a random subset of the held-out test split is classified by majority
# vote among the k nearest reference images under the L1 distance.
import random

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from PIL import Image
from sklearn import datasets
from tensorflow.examples.tutorials.mnist import input_data

sess = tf.Session()
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# This example has 10 classes (digits 0-9, one-hot encoded).
train_size = 1000
test_size = 102
rand_train_indices = np.random.choice(
    len(mnist.train.images), train_size, replace=False)
# Draw the evaluation samples from the held-out test split. Sampling them
# from the training split (as before) lets a test image also appear among
# the reference images, which inflates the reported accuracy.
rand_test_indices = np.random.choice(
    len(mnist.test.images), test_size, replace=False)
x_vals_train = mnist.train.images[rand_train_indices]
x_vals_test = mnist.test.images[rand_test_indices]
y_vals_train = mnist.train.labels[rand_train_indices]
y_vals_test = mnist.test.labels[rand_test_indices]

k = 4
batch_size = 6

x_data_train = tf.placeholder(shape=[None, 784], dtype=tf.float32)
x_data_test = tf.placeholder(shape=[None, 784], dtype=tf.float32)
y_target_train = tf.placeholder(shape=[None, 10], dtype=tf.float32)
y_target_test = tf.placeholder(shape=[None, 10], dtype=tf.float32)

# L1 distance between every test image and every reference image.
# Broadcasting: (train, 784) - (batch, 1, 784) -> (batch, train, 784);
# summing over the pixel axis gives shape (batch, train).
distance = tf.reduce_sum(
    tf.abs(tf.subtract(x_data_train, tf.expand_dims(x_data_test, 1))),
    axis=2)

# top_k selects the largest values, so negate the distances to obtain the
# k *nearest* neighbours. Indices have shape (batch, k).
_, top_k_indices = tf.nn.top_k(tf.negative(distance), k=k)

# One-hot labels of the k neighbours: gather((train, 10), (batch, k))
# -> (batch, k, 10).
prediction_indices = tf.gather(y_target_train, top_k_indices)

# Per-class vote counts, shape (batch, 10).
count_of_prediction = tf.reduce_sum(prediction_indices, axis=1)

# Predicted class is the one with the most votes, shape (batch,).
prediction = tf.argmax(count_of_prediction, axis=1)

num_loop = int(np.ceil(len(x_vals_test) / batch_size))
test_output = []
actual_vals = []
for i in range(num_loop):
    min_index = i * batch_size
    max_index = min((i + 1) * batch_size, len(x_vals_test))
    # Slice out the current batch of test data.
    x_batch = x_vals_test[min_index:max_index]
    y_batch = y_vals_test[min_index:max_index]
    predictions = sess.run(
        prediction,
        feed_dict={
            x_data_test: x_batch,
            x_data_train: x_vals_train,
            y_target_test: y_batch,
            y_target_train: y_vals_train,
        })
    test_output.extend(predictions)
    actual_vals.extend(np.argmax(y_batch, axis=1))

# The graph is no longer needed once all batches are evaluated.
sess.close()

# Accuracy: fraction of test samples whose prediction matches the label.
num_correct = sum(
    1 for predicted, actual in zip(test_output, actual_vals)
    if predicted == actual)
accuracy = num_correct / float(len(actual_vals))
print("Accuarcy: " + str(accuracy))

# Plot the images of the last evaluated batch with actual/predicted labels.
actuals = np.argmax(y_batch, axis=1)
n_cols = 3
# Derive the row count from the batch length so a short final batch (or a
# different batch_size) still fits the subplot grid.
n_rows = int(np.ceil(len(actuals) / float(n_cols)))
for i, (actual, predicted) in enumerate(zip(actuals, predictions)):
    plt.subplot(n_rows, n_cols, i + 1)
    plt.imshow(np.reshape(x_batch[i], [28, 28]), cmap="Greys_r")
    plt.title('Actual: ' + str(actual) + ' Pred:' + str(predicted),
              fontsize=10)
    frame = plt.gca()
    frame.axes.get_xaxis().set_visible(False)
    frame.axes.get_yaxis().set_visible(False)
plt.show()
TensorFlow(八) TensorFlow圖像識別(KNN)