1. 程式人生 > 實用技巧 >python實現簡單的神經網路_mnist資料集神經網路實現

python實現簡單的神經網路_mnist資料集神經網路實現

實現流程

1、準備資料

2、全連線結果計算

3、損失優化(梯度下降)

4、模型評估(計算準確性)

5、加入tensorboard圖

 1 def full_connect():
 2     #使用佔位符時,tersorflow2.X以上會出現tf.placeholder() is not compatible with eager execution報錯,需要加下面這段語,避免程式報此錯誤。
 3     tf.compat.v1.disable_eager_execution()
 4     #獲取真實的資料
 5     mnist = input_data.read_data_sets("
./tmp/mnist/", one_hot=True) 6 #1、建立資料的佔位符 ,X[None,784] y_true [None,10] 7 with tf.compat.v1.variable_scope('data'): 8 x=tf.compat.v1.placeholder(tf.float32,[None,784]) 9 y_true=tf.compat.v1.placeholder(tf.int32,[None,10]) 10 11 #2、建立一個全連結層的神經網路 w[784,10],b=[10] 12 with tf.compat.v1.variable_scope('
fc_model'): 13 #隨機初始化權重和偏置,權重和偏置後面會跟著訓練自動優化 14 weight=tf.Variable(tf.compat.v1.random_normal([784,10],mean=0.0,stddev=1.0),name='weight') 15 bias=tf.Variable(tf.constant(0.0,shape=[10])) 16 #預測Nonew個樣本的輸出結果matrix [None,784]*[784*10]+[10]=[None,10],即矩陣[None,784]樣本的特徵*權重[784,10]+偏置[10]=預測結果[None,10]
17 y_predict=tf.matmul(x,weight)*bias 18 #計算交叉熵損失 19 with tf.compat.v1.variable_scope('soft_cross'): 20 #返回交叉熵的列表結果,對交叉熵求平均值 21 loss=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_true,logits=y_predict)) 22 23 #梯度下降求出損失 24 with tf.compat.v1.variable_scope('optimizer'): 25 train_op=tf.compat.v1.train.GradientDescentOptimizer(0.1).minimize(loss) 26 #5、計算準確率,預測準確置為1 27 with tf.compat.v1.variable_scope('acc'): 28 #equal_list None個樣本[1,0,1,1,.....] 29 equal_list=tf.equal(tf.argmax(y_true,1),tf.argmax(y_predict,1)) 30 accuray=tf.reduce_mean(tf.cast(equal_list,tf.float32)) 31 #收集變數,單個數字值收集 32 tf.compat.v1.summary.scalar("losses",loss) 33 tf.compat.v1.summary.scalar("acc", accuray) 34 35 #高緯度變數收集 36 tf.compat.v1.summary.histogram('weight',weight) 37 tf.compat.v1.summary.histogram('biases',bias) 38 39 #定義一個合併的op 40 merged=tf.compat.v1.summary.merge_all() 41 42 #因為有變數,故要定義初始化變數的op 43 init_op=tf.compat.v1.global_variables_initializer() 44 #開啟回話去訓練 45 with tf.compat.v1.Session() as sess: 46 #初始化變數 47 sess.run(init_op) 48 filewriter=tf.compat.v1.summary.FileWriter('./tmp/summary/test/',graph=sess.graph) 49 #迭代步數去訓練 ,更新引數預測 50 for i in range(2000): 51 mnist_x,mnist_y=mnist.train.next_batch(50) 52 #feed_dict實時提供的資料 x訓練集,y為真實的目標值 53 #執行op訓練 54 sess.run(train_op,feed_dict={x:mnist_x,y_true:mnist_y}) 55 #寫入每步訓練的值 56 summary=sess.run(merged,feed_dict={x:mnist_x,y_true:mnist_y}) 57 filewriter.add_summary(summary,i) 58 59 print('訓練第%d步,準確率為:%f'%(i,sess.run(accuray,feed_dict={x:mnist_x,y_true:mnist_y}))) 60 return None

注意:在tensorflow2.X版本,如果出現報No module named 'tensorflow.examples.tutorials' ,手動下載tutorials檔案包,並放到本地電腦tersorflow/examples目錄下。

下載連結:https://share.weiyun.com/fpYSBj4X 密碼:qu73et

出現報錯tensorflow報AttributeError: __enter__,將tf.compat.v1.Session後面加上括號()