昨天學習了一些tensorflow入門知識,經歷各種奇葩錯誤,現在奉獻一份安裝tensorflow2就可以跑的demo
阿新 • • 發佈:2020-08-14
本程式使用minist影象集合作為資料來源,使用tensorflow內部的資料載入方式(如果沒有資料集,會自動從網上下載).神經網路內層有三層,依靠純手工搭建網路模式,比較貼近數學模型
1 #encoding: utf-8 2 import os 3 os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # 關閉低階的除錯資訊 4 5 import tensorflow as tf 6 from tensorflow import keras 7 from tensorflow.keras import datasets 8 9 lr = 0.001 10(x,y), _ = datasets.mnist.load_data() 11 12 # x 歸一化 13 x = tf.convert_to_tensor(x,dtype=tf.float32)/255. 14 y = tf.convert_to_tensor(y,dtype=tf.int32) 15 print(x.shape,y.shape,x.dtype,y.dtype) 16 print(tf.reduce_min(x),tf.reduce_max(x)) 17 print(tf.reduce_min(y),tf.reduce_max(y)) 18 19 # 將源資料分割,一次處理128條資料。分60000/128此處理結束,返回一個迭代器20 train_db = tf.data.Dataset.from_tensor_slices((x,y)).batch(128) 21 train_iter = iter(train_db) 22 sample = next(train_iter) 23 print("batch:",sample[0].shape,sample[1].shape) 24 25 #生成引數矩陣 注意tf.Variable 大小寫 26 w1 = tf.Variable(tf.random.truncated_normal([784,256],stddev=0.1)) 27 b1 = tf.Variable(tf.zeros([256]))28 w2 = tf.Variable(tf.random.truncated_normal([256,128],stddev=0.1)) 29 b2 = tf.Variable(tf.zeros([128])) 30 w3 = tf.Variable(tf.random.truncated_normal([128,10],stddev=0.1)) 31 b3 = tf.Variable(tf.zeros([10])) 32 33 for epoch in range(30): #迭代一次即完成60k迴圈 34 for step,(x,y) in enumerate(train_db): #迭代一次即完成一次128次訓練 35 # x =>128,28,28 需要轉換為128,28*28 36 x = tf.reshape(x,[-1,28*28]) 37 # 注意加() 38 with tf.GradientTape() as tape: 39 h1 = x@w1 + b1 40 h1 = tf.nn.relu(h1) 41 h2 = h1@w2 + b2 42 h2 = tf.nn.relu(h2) 43 out = h2@w3 + b3 44 45 # 要對y進行one_hot編碼,好處之一,使得不同結果之前的距離可以保持一致 46 y_one_hot = tf.one_hot(y,depth=10) 47 loss = tf.square(y_one_hot-out) 48 loss = tf.reduce_mean(loss) 49 50 grads = tape.gradient(loss,[w1,b1,w2,b2,w3,b3]) 51 # w1 = w1 - lr*grads[0],這個更新方式會更改原來的值,不再是tf.Variable 52 # 此處採用原地更新的方式,如下 (注意函式名字不要寫錯) 53 w1.assign_sub(lr * grads[0]) 54 b1.assign_sub(lr * grads[1]) 55 w2.assign_sub(lr * grads[2]) 56 b2.assign_sub(lr * grads[3]) 57 w3.assign_sub(lr * grads[4]) 58 b3.assign_sub(lr * grads[5]) 59 60 if step%100 == 0: 61 print(epoch,step,"loss:",float(loss))