tensorflow入門線性迴歸
阿新 • • 發佈:2018-11-02
實際上編寫tensorflow可以總結為兩步.
(1)組裝一個graph;
(2)使用session去執行graph中的operation。
當使用tensorflow進行graph構建時,大體可以分為五部分:
1、為輸入X與輸出y定義placeholder;
2、定義權重W;
3、定義模型結構;
4、定義損失函式;
5、定義優化演算法
下面是手寫識別字程式:
-
import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets("MNIST_data/",one_hot=True) #匯入資料集 x = tf.placeholder(shape=[None,784],dtype=tf.float32) y = tf.placeholder(shape=[None,10],dtype=tf.float32) #為輸入輸出定義placehloder w = tf.Variable(tf.truncated_normal(shape=[784,10],mean=0,stddev=0.5)) b = tf.Variable(tf.zeros([10])) #定義權重 y_pred = tf.nn.softmax(tf.matmul(x,w)+b) #定義模型結構 loss =tf.reduce_mean(-tf.reduce_sum(y*tf.log(y_pred),reduction_indices=[1])) #定義損失函式 opt = tf.train.GradientDescentOptimizer(0.05).minimize(loss) #定義優化演算法 sess =tf.Session() sess.run(tf.global_variables_initializer()) for each in range(1000): batch_xs,batch_ys = mnist.train.next_batch(100) loss1 = sess.run(loss,feed_dict={x:batch_xs,y:batch_ys}) opt1 = sess.run(opt,feed_dict={x:batch_xs,y:batch_ys}) print(loss1)