1. 程式人生 > >TensorFlow入門(5)——搭建一個簡單的全連線網路

TensorFlow入門(5)——搭建一個簡單的全連線網路

本文參考了https://blog.csdn.net/login_sonata/article/details/77620328 的原始碼,並加以修改。

修改為輸入層節點數為2,隱藏層為10個節點,輸出層節點數為2。

'''
Created on 2018年4月9日

@author: yqy
'''

import tensorflow as tf
import numpy as np

# 新增神經層的函式,它有四個引數:輸入值、輸入的形狀、輸出的形狀和激勵函式,
# Wx_plus_b是未啟用的值,函式返回啟用值。
def add_layer(inputs, in_size, out_size,  activation_function=None):
    # tf.random_normal()引數為shape,還可以指定均值和標準差
    Weights = tf.Variable(tf.random_normal([in_size, out_size]))
    biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
    Wx_plus_b = tf.matmul(inputs, Weights) + biases#要求輸入向量x是個行向量
    if activation_function is None:
        outputs = Wx_plus_b
    else:
        outputs = activation_function(Wx_plus_b)
    return outputs

# 構建訓練資料
# np.linspace()在-1和1之間等差生成300個數字
# noise是正態分佈的噪聲,前兩個引數是正態分佈的引數,然後是size
x_data = np.linspace(-1,1,300, dtype=np.float32)[:, np.newaxis]#將x_data轉換為列向量
noise = np.random.normal(0, 0.05,  x_data.shape).astype(np.float32)
y_data = np.square(x_data) - 0.5 + noise

# 利用佔位符定義我們所需的神經網路的輸入。
# 第二個引數為shape:None代表行數不定,2是列數。
# 這裡的行數就是樣本數,列數是每個樣本的特徵數。
xs = tf.placeholder(tf.float32, [None, 2])
ys = tf.placeholder(tf.float32, [None, 2])

# 輸入層2個神經元(每次輸入兩個相同的x,輸出兩個相同的y),隱藏層10個,輸出層2個。
# 呼叫函式定義隱藏層和輸出層,輸入size是上一層的神經元個數(全連線),輸出size是本層個數。
l1 = add_layer(xs, 2, 10, activation_function=tf.nn.relu)
prediction = add_layer(l1, 10, 2, activation_function=None)

# 計算預測值prediction和真實值的誤差,對二者差的平方求和再取平均作為損失函式。
# reduction_indices表示最後資料的壓縮維度,好像一般不用這個引數(即降到0維,一個標量)。
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),reduction_indices=[1]))
train_step =  tf.train.GradientDescentOptimizer(0.2).minimize(loss)

# 初始化變數,啟用,執行運算
init = tf.global_variables_initializer() 
with tf.Session() as sess:
    sess.run(init)
    print(x_data.shape)
    
    for i in range(1000):
        # training
        #使用np.concatenate將原來的一列x合併為兩列每行都相同的x
        sess.run(train_step,feed_dict={xs:np.concatenate((x_data,x_data),axis=1),ys:np.concatenate((y_data,y_data),axis=1)})
        if i % 50 == 0:
            print (sess.run(loss,feed_dict={xs:np.concatenate((x_data,x_data),axis=1),ys:np.concatenate((y_data,y_data),axis=1)}))