tensorflow 自己手動實現的線性迴歸
阿新 • • 發佈:2018-12-23
tensorflow 自己手動實現的線性迴歸
#!/usr/bin/python
# -*- coding:utf-8 -*-
import os

import tensorflow as tf

# The second argument of each DEFINE_* call is the flag's default value.
tf.app.flags.DEFINE_integer("max_iter", 100, "迭代次數")
tf.app.flags.DEFINE_string("model_dir", "./tmp/ckpt/model", "模型路徑")
tf.app.flags.DEFINE_string("summary_dir", "./tmp/test/", "graph路徑")
tf.app.flags.DEFINE_string("checkpoint_dir", "./tmp/ckpt/checkpoint", "模型路徑")
FLAGS = tf.app.flags.FLAGS


def mylineregression():
    """Fit a hand-written linear regression with gradient descent.

    Builds a graph that generates synthetic data following
    ``y = 0.7 * x + 0.8``, fits a single weight and bias to it by minimising
    the mean squared error, and runs ``FLAGS.max_iter`` training steps.

    Side effects:
        * Prints the current weight and bias before every training step.
        * Writes the graph plus loss/weight/bias summaries to
          ``FLAGS.summary_dir``.
        * If the path in ``FLAGS.checkpoint_dir`` exists, variables are
          restored from ``FLAGS.model_dir`` before training starts.

    Returns:
        None.
    """
    with tf.variable_scope("data"):
        x = tf.random_normal([100, 1], 0.0, 1.0)
        # Ground truth the model is expected to recover: weight 0.7, bias 0.8.
        y = tf.multiply(x, [[0.7]]) + 0.8

    with tf.variable_scope("model"):
        weight = tf.Variable(tf.random_normal([1, 1], 0.0, 1.0))
        bias = tf.Variable(0.0)
        y_predict = tf.multiply(x, weight) + bias

    with tf.variable_scope("loss"):
        loss = tf.reduce_mean(tf.square(y - y_predict))

    with tf.variable_scope("optimizer"):
        train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

    init_value = tf.global_variables_initializer()
    saver = tf.train.Saver()

    tf.summary.scalar("losses", loss)
    tf.summary.histogram("weight", weight)
    tf.summary.histogram("bias", bias)
    merged = tf.summary.merge_all()

    with tf.Session() as sess:
        sess.run(init_value)
        file_writer = tf.summary.FileWriter(FLAGS.summary_dir, graph=sess.graph)
        try:
            # Load a previously saved model, overwriting the freshly
            # initialised variable values.  The default of `checkpoint_dir`
            # is a file path ending in "checkpoint" -- presumably the index
            # file a Saver writes next to the model; verify if paths change.
            if os.path.exists(FLAGS.checkpoint_dir):
                saver.restore(sess, FLAGS.model_dir)

            for i in range(FLAGS.max_iter):
                # Fetch both variables in a single run instead of two evals.
                weight_value, bias_value = sess.run([weight, bias])
                # `weight` has shape (1, 1); .item() extracts the scalar
                # explicitly rather than relying on implicit array-to-float
                # conversion inside the %-format.
                print("第%d次訓練引數weight:%f,bias:%f"
                      % (i, weight_value.item(), bias_value))

                # NOTE(review): the summary and the training step are
                # evaluated in separate runs (as before); since `x` is a
                # random op they presumably see different samples -- confirm
                # if the logged loss must match the batch that was trained on.
                summary = sess.run(merged)
                file_writer.add_summary(summary, i)
                sess.run(train_op)
        finally:
            # Flush pending events and release the event file even when
            # restoring or training raises.
            file_writer.close()

        # NOTE: saving the model is left disabled, matching the original
        # script; call saver.save(sess, FLAGS.model_dir) here to persist it.


if __name__ == '__main__':
    print("hello")
    mylineregression()