1. 程式人生 > >完整神經網路的樣例程式

完整神經網路的樣例程式

# 總結:
# 訓練神經網路的過程可以分為以下三個步驟:
# 1、定義神經網路的結構和前向傳播的輸出結果
# 2、定義損失函式以及選擇反向傳播優化的演算法
# 3、生成會話(tf.Session)並且在訓練資料上反覆執行反向傳播優化演算法



import tensorflow as tf
#NumPy是一個科學計算的工具包,這裡通過NumPy生成模擬資料集
from numpy.random import RandomState

#定義訓練資料batch的大小
batch_size = 8

#定義神經網路的引數
w1 = tf.Variable(tf.random_normal([2, 3], stddev=1, seed=1))
w2 = tf.Variable(tf.random_normal([3, 1], stddev=1, seed=1))

#在shape的一個維度上使用None可以方便使用不大的batch大小。在訓練師要把資料
# 分成比較小的batch,但在測試時,可以一次性使用全部的資料。當資料及比較小時
# 這樣比較方便測試,但資料集比較大時,將大量資料放入一個batch可能會導致記憶體溢位。
x = tf.placeholder(tf.float32, shape=(None, 2), name="x-input")
y_ = tf.placeholder(tf.float32, shape=[None, 1], name="y-input")

#定義神經網路向前傳播的過程。
a = tf.matmul(x, w1)
y = tf.matmul(a, w2)

#定義損失函式和反向傳播演算法。
cross_entropy = -tf.reduce_mean(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))

train_step = tf.train.AdamOptimizer(0.001).minimize(cross_entropy)
#Tensorflow經過使用Adam優化演算法對損失函式中變數進行修改值
#預設修改tf.Variable型別的引數。
#也可以使用var_list引數來定義更新哪些引數
# 如:train_step = tf.train.AdamOptimizer(0.001).minimize(cross_entropy,var_list=[w1, w2])


#通過隨機數生成一個模擬資料集
rdm = RandomState(1)
dataset_size = 128
X = rdm.rand(dataset_size, 2)

#定義規則來給出樣本的標籤。在這裡所有x1+x2<1的樣例都被認為是正樣本
# 而其他為負樣本。和Tensorflow遊樂場中的表示法不大一樣的地方是,
# 在這裡用0來表示負樣本,1表示正樣本。大部分解決分類問題的神經網路都會採用
# 0和1的表示方法
Y = [[int(x1+x2 < 1)] for (x1, x2) in X]

#建立一個回話來執行Tensorflow程式
with tf.Session() as sess:
    iniy_op = tf.global_variables_initializer()
    #初始化變數
    sess.run(iniy_op)
    print(sess.run(w1))
    print(sess.run(w2))
    STEPS = 5000
    for i in range(STEPS):
        #每次選取batch_size個樣本進行訓練
        start = (i*batch_size) % dataset_size
        end = min(start+batch_size, dataset_size)
        #通過選取的樣本訓練神經網路並更新引數

        sess.run(train_step, feed_dict={x: X[start:end], y_: Y[start:end]})
        if i % 1000 == 0:
            #每隔一段時間計算在所有資料上的交叉熵並輸出。
            total_cross_entropy = sess.run(
                cross_entropy, feed_dict={x: X, y_: Y})
            print("After %d training step(s), cross entropy on all data is %g"%(i, total_cross_entropy))
    print(sess.run(w1))
    print(sess.run(w2))