
Using TensorFlow to predict a function's parameter values (a simple task)

Given x1, x2, x3 and y, related by
y = a*a*x1 + b*b*x2 + a*b*x3
predict the parameters a and b.
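The code below fits a and b by minimising the mean squared error over the N = 100 samples. TensorFlow's optimizer derives the gradients automatically. Written out by hand (a derivation added here for illustration, not part of the original post), with r_i the residual of sample i, they are:

```latex
\hat{y}_i = a^2 x_{1,i} + b^2 x_{2,i} + ab\,x_{3,i}, \qquad r_i = \hat{y}_i - y_i
L(a,b) = \frac{1}{N}\sum_{i=1}^{N} r_i^2
\frac{\partial L}{\partial a} = \frac{2}{N}\sum_{i=1}^{N} r_i\,(2a\,x_{1,i} + b\,x_{3,i})
\frac{\partial L}{\partial b} = \frac{2}{N}\sum_{i=1}^{N} r_i\,(2b\,x_{2,i} + a\,x_{3,i})
```

Both gradients vanish at a = 0.5, b = 0.8, the values used to generate y_data, because every residual is zero there.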

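# TensorFlow 1.x graph-mode API. Importing it via tf.compat.v1 lets the same
# code run on recent TF 1.x releases and on TF 2.x.
import tensorflow.compat.v1 as tf
import numpy as np

tf.disable_v2_behavior()  # enable TF1 graph-mode semantics (tf.Session etc.) under TF 2.x

# Synthetic data: 100 samples per input, generated with true values a = 0.5, b = 0.8
x1_data = np.random.rand(100).astype(np.float32)
x2_data = np.random.rand(100).astype(np.float32)
x3_data = np.random.rand(100).astype(np.float32)
y_data = x1_data*0.5*0.5 + x2_data*0.8*0.8 + x3_data*0.5*0.8

# Trainable parameters, initialised uniformly in [0, 1)
a = tf.Variable(tf.random_uniform([1]))
b = tf.Variable(tf.random_uniform([1]))

# Model prediction: y = a*a*x1 + b*b*x2 + a*b*x3
y = a*a*x1_data + b*b*x2_data + a*b*x3_data

# Mean squared error, minimised with Adam (learning rate 0.1)
loss = tf.reduce_mean(tf.square(y - y_data))
optimizer = tf.train.AdamOptimizer(0.1)
train = optimizer.minimize(loss)

# initialize_all_variables() is deprecated; global_variables_initializer() replaces it
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    for step in range(202):
        sess.run(train)
        if step % 20 == 0:
            # Fetch a, b and the loss in a single run call
            a_val, b_val, loss_val = sess.run([a, b, loss])
            print(step, a_val, b_val, loss_val)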
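The TensorFlow code above relies on automatic differentiation and Adam. For readers without TensorFlow installed, here is a minimal plain-Python sketch of the same fit under stated assumptions: it swaps Adam for ordinary full-batch gradient descent, with the gradients of the mean squared error derived by hand from the model. The helper names (make_data, loss_and_grads, fit), the learning rate 0.1, the 2000 steps and the seeds are hypothetical choices for illustration, not taken from the original post.

```python
import random

# True parameters used to generate the data, as in the post (a = 0.5, b = 0.8)
TRUE_A, TRUE_B = 0.5, 0.8


def make_data(n=100, seed=0):
    """Hypothetical helper: n samples (x1, x2, x3, y) with x's uniform in [0, 1)."""
    rng = random.Random(seed)
    rows = []
    for _ in range(n):
        x1, x2, x3 = rng.random(), rng.random(), rng.random()
        y = TRUE_A * TRUE_A * x1 + TRUE_B * TRUE_B * x2 + TRUE_A * TRUE_B * x3
        rows.append((x1, x2, x3, y))
    return rows


def loss_and_grads(a, b, rows):
    """Mean squared error and its hand-derived partial derivatives w.r.t. a and b."""
    n = len(rows)
    loss = grad_a = grad_b = 0.0
    for x1, x2, x3, y in rows:
        r = a * a * x1 + b * b * x2 + a * b * x3 - y  # residual of this sample
        loss += r * r
        grad_a += 2 * r * (2 * a * x1 + b * x3)  # d(r^2)/da
        grad_b += 2 * r * (2 * b * x2 + a * x3)  # d(r^2)/db
    return loss / n, grad_a / n, grad_b / n


def fit(rows, lr=0.1, steps=2000, seed=1):
    """Plain gradient descent (not Adam) starting from values in [0, 1)."""
    rng = random.Random(seed)
    a, b = rng.random(), rng.random()  # mimics tf.random_uniform's default [0, 1) range
    for _ in range(steps):
        _, grad_a, grad_b = loss_and_grads(a, b, rows)
        a -= lr * grad_a
        b -= lr * grad_b
    final_loss, _, _ = loss_and_grads(a, b, rows)
    return a, b, final_loss


if __name__ == "__main__":
    a, b, final_loss = fit(make_data())
    print(a, b, final_loss)
```

The learning rate is kept small on purpose: plain gradient descent, unlike Adam, can overshoot and diverge on this problem if the step size is too large.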

Output (the data and the initial values are not seeded, so the exact numbers differ from run to run):
0 [0.16722116] [0.4231345] 0.287351
20 [0.46662363] [0.81065917] 0.00041220474
40 [0.51124895] [0.83035856] 0.0021522618
60 [0.50623244] [0.79908705] 2.3531691e-05
80 [0.49737427] [0.79750854] 2.718182e-05
100 [0.50043434] [0.80122584] 3.336514e-06
120 [0.5001132] [0.799731] 5.6176066e-08
140 [0.4998217] [0.7999378] 5.3850208e-08
160 [0.50005966] [0.8000353] 8.847866e-09
180 [0.50001425] [0.8000164] 1.0047848e-09
200 [0.49999878] [0.79999965] 2.2931767e-12
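In the output above both estimates converge to positive values (0.49999878 and 0.79999965 at step 200). The model only determines a and b up to a joint sign flip, though: a*a, b*b and a*b are all unchanged when both signs flip, so (-0.5, -0.8) fits the data equally well. The positive pair is presumably reached because tf.random_uniform starts both variables in [0, 1). The short check below illustrates this; the helper predict and the sample inputs are hypothetical, not from the original.

```python
import math


# Hypothetical helper: the model from the post, y = a*a*x1 + b*b*x2 + a*b*x3
def predict(a, b, x1, x2, x3):
    return a * a * x1 + b * b * x2 + a * b * x3


# A few arbitrary inputs in [0, 1), like np.random.rand would produce
samples = [(0.1, 0.2, 0.3), (0.9, 0.4, 0.7), (0.5, 0.5, 0.5)]

for x1, x2, x3 in samples:
    # Flipping both signs leaves a*a, b*b and a*b unchanged
    assert math.isclose(predict(0.5, 0.8, x1, x2, x3), predict(-0.5, -0.8, x1, x2, x3))
    # Flipping only b negates the a*b*x3 term, so the prediction changes (x3 > 0 here)
    assert not math.isclose(predict(0.5, 0.8, x1, x2, x3), predict(0.5, -0.8, x1, x2, x3))

print("(0.5, 0.8) and (-0.5, -0.8) give identical predictions")
```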