【TensorFlow】3-2構建簡單單層神經網路進行【手寫字元識別】
阿新 • • 發佈:2018-12-20
自動下載並轉化MNIST資料集格式到TF中
import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data #載入MNIST資料集,如果不存在,將自動在預設網址下載,並被TF簡單處理 mnist = input_data.read_data_sets("MNIST_data",one_hot=True) #匯入資料集後,60000個數據,自動分為了55000 train_data 10000 test_data. # 50000 train_data中有分出5000個作為訓練過程的valiadtion n_train =mnist.train.num_examples n_test =mnist.test.num_examples n_validtion=mnist.validation.num_examples image_data =mnist.train.images[0] label_data =mnist.train.labels[0] print ("n_train,n_validtion,n_test",n_train,n_validtion,n_test) #輸出(55000,5000,10000) print("image_data",image_data) #輸出28*28的矩陣,值為[0,1]之間 print("image_label",label_data) #輸出[0.,0.,0.,0.,0.,0.,0.,1.,0.,0.] #資料集從資料集中提取出一小部分隨機訓練 batch_size=100 xs,ys=mnist.train.next_batch(batch_size) print("X shape:",xs.shape) #輸出 [100,784] print("Y shape:",ys.shape) #輸出 [100,10]
完整程式碼
# -*- coding: utf-8 -*- import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data #載入MNIST資料集,如果不存在,將自動在預設網址下載 mnist = input_data.read_data_sets("MNIST_data",one_hot=True) #匯入資料集後,60000個數據,自動分為了55000 train_data 10000 test_data. # 50000 train_data中有分出5000個作為訓練過程的valiadtion n_train =mnist.train.num_examples n_test =mnist.test.num_examples n_validtion=mnist.validation.num_examples image_data =mnist.train.images[0] label_data =mnist.train.labels[0] #print ("n_train,n_validtion,n_test",n_train,n_validtion,n_test) ##輸出(55000,5000,10000) #print("image_data",image_data) ##輸出28*28的矩陣,值為[0,1]之間 #print("image_label",label_data) ##輸出[0.,0.,0.,0.,0.,0.,0.,1.,0.,0.] #資料集從資料集中提取出一小部分隨機訓練 batch_size=100 xs,ys=mnist.train.next_batch(batch_size) print("X shape:",xs.shape) #輸出 [100,784] print("Y shape:",ys.shape) #輸出 [100,10] #定義簡單的神經網路,784個輸入節點,10個輸出神經元,即單層神經網路 INPUT_NODE = 784 OUTPUT_NODE= 10 x=tf.placeholder(tf.float32,[None,INPUT_NODE]) y=tf.placeholder(tf.float32,[None,OUTPUT_NODE]) W1=tf.Variable(tf.zeros([INPUT_NODE,OUTPUT_NODE])) b1=tf.Variable(tf.zeros([OUTPUT_NODE])) prediction=tf.nn.softmax(tf.matmul(x,W1)+b1) #二次代價函式 loss=tf.reduce_mean(tf.square(y-prediction)) #使用梯度下降法優化 train_step=tf.train.GradientDescentOptimizer(0.2).minimize(loss) #將預測結果與便籤結果的對比值存於布林列表裡 #tf.arg_max按行輸入最大值的索引 correct_prediction=tf.equal(tf.arg_max(prediction,1),tf.arg_max(y,1)) #求神經網路的準確率 accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) #訓練神經網路 with tf.Session() as sess: #初始化變數 sess.run(tf.global_variables_initializer()) for epoch in range(10000): xs,ys=mnist.train.next_batch(batch_size) sess.run(train_step,feed_dict={x:xs,y:ys}) if epoch % 1000==0: acc=sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels}) print("epoch:"+str(epoch)+",Test Accuracy"+str(acc))
訓練資料,驗證資料和測試資料關係【參考連結】
- 訓練資料(Training Data):用於模型構建
- 驗證資料(Validation Data):可選,用於輔助模型構建,可以重複使用。
- 測試資料(Test Data):用於檢測模型構建,此資料只在模型檢驗時使用,用於評估模型的準確率。絕對不允許用於模型構建過程,否則會導致過渡擬合。
tensorflow tf.argmax() 用法講解
- tf.argmax(input, axis=None, name=None, dimension=None)
- 對矩陣按行或列計算最大值
- 四個引數:
- 1).input:輸入值
- 2).axis:可選值0表示按列,1表示按行求最大值。axis兩個引數的區別是:0是每個陣列對應位置之間的比較,而1則是陣列內部元素之間的比較。
- 3).name
- 4).預設使用axis即可.