TensorFlow入門(5)——搭建一個簡單的全連線網路
阿新 • • 發佈:2019-01-23
本文參考了https://blog.csdn.net/login_sonata/article/details/77620328 的原始碼,並加以修改。
修改為輸入層節點數為2,隱藏層為10個節點,輸出層節點數為2。
''' Created on 2018年4月9日 @author: yqy ''' import tensorflow as tf import numpy as np # 新增神經層的函式,它有四個引數:輸入值、輸入的形狀、輸出的形狀和激勵函式, # Wx_plus_b是未啟用的值,函式返回啟用值。 def add_layer(inputs, in_size, out_size, activation_function=None): # tf.random_normal()引數為shape,還可以指定均值和標準差 Weights = tf.Variable(tf.random_normal([in_size, out_size])) biases = tf.Variable(tf.zeros([1, out_size]) + 0.1) Wx_plus_b = tf.matmul(inputs, Weights) + biases#要求輸入向量x是個行向量 if activation_function is None: outputs = Wx_plus_b else: outputs = activation_function(Wx_plus_b) return outputs # 構建訓練資料 # np.linspace()在-1和1之間等差生成300個數字 # noise是正態分佈的噪聲,前兩個引數是正態分佈的引數,然後是size x_data = np.linspace(-1,1,300, dtype=np.float32)[:, np.newaxis]#將x_data轉換為列向量 noise = np.random.normal(0, 0.05, x_data.shape).astype(np.float32) y_data = np.square(x_data) - 0.5 + noise # 利用佔位符定義我們所需的神經網路的輸入。 # 第二個引數為shape:None代表行數不定,2是列數。 # 這裡的行數就是樣本數,列數是每個樣本的特徵數。 xs = tf.placeholder(tf.float32, [None, 2]) ys = tf.placeholder(tf.float32, [None, 2]) # 輸入層2個神經元(每次輸入兩個相同的x,輸出兩個相同的y),隱藏層10個,輸出層2個。 # 呼叫函式定義隱藏層和輸出層,輸入size是上一層的神經元個數(全連線),輸出size是本層個數。 l1 = add_layer(xs, 2, 10, activation_function=tf.nn.relu) prediction = add_layer(l1, 10, 2, activation_function=None) # 計算預測值prediction和真實值的誤差,對二者差的平方求和再取平均作為損失函式。 # reduction_indices表示最後資料的壓縮維度,好像一般不用這個引數(即降到0維,一個標量)。 loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),reduction_indices=[1])) train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) # 初始化變數,啟用,執行運算 init = tf.global_variables_initializer() with tf.Session() as sess: sess.run(init) print(x_data.shape) for i in range(1000): # training #使用np.concatenate將原來的一列x合併為兩列每行都相同的x sess.run(train_step,feed_dict={xs:np.concatenate((x_data,x_data),axis=1),ys:np.concatenate((y_data,y_data),axis=1)}) if i % 50 == 0: print (sess.run(loss,feed_dict={xs:np.concatenate((x_data,x_data),axis=1),ys:np.concatenate((y_data,y_data),axis=1)}))