卷積神經網路分類 MNIST 手寫體數字
阿新 • 發佈：2018-12-31
"""Train a small convolutional network on MNIST handwritten digits (TF1 graph API).

Architecture: two stages of 3x3 convolution + ReLU + 2x2 max-pool (16 and 32
channels), a 128-unit fully connected ReLU layer and a 10-way softmax output.
Training accuracy and loss are plotted live with matplotlib; accuracy on the
MNIST test split is printed at the end.
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt

# Load MNIST from ./MNIST_data (read_data_sets presumably downloads it on first run);
# one_hot=True gives label vectors of length 10, matching Net.y.
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)


class Net:
    """CNN classifier for 28x28 single-channel images, built as a TF1 graph.

    Call ``forward()`` and then ``backward()`` to create the inference and
    training ops before running them in a ``tf.Session``.
    """

    def __init__(self):
        # Inputs: a batch of 28x28x1 images and their one-hot labels over 10 classes.
        self.x = tf.placeholder(tf.float32, [None, 28, 28, 1])
        self.y = tf.placeholder(tf.float32, [None, 10])

        # 3x3 convolution kernels: 1 -> 16 channels, then 16 -> 32 channels.
        self.conv1_w = tf.Variable(tf.random_normal([3, 3, 1, 16], dtype=tf.float32, stddev=0.1))
        self.conv1_b = tf.Variable(tf.zeros([16]))
        self.conv2_w = tf.Variable(tf.random_normal([3, 3, 16, 32], dtype=tf.float32, stddev=0.1))
        self.conv2_b = tf.Variable(tf.zeros([32]))

        # Fully connected layers: flattened 7x7x32 feature map -> 128 units -> 10 classes.
        self.w1 = tf.Variable(tf.random_normal([7 * 7 * 32, 128], stddev=0.1))
        self.b1 = tf.Variable(tf.zeros([128]))
        self.w2 = tf.Variable(tf.random_normal([128, 10], stddev=0.1))
        self.b2 = tf.Variable(tf.zeros([10]))

    def forward(self):
        """Build the inference graph.

        Sets ``self.logits`` (pre-softmax class scores) and ``self.y2``
        (softmax class probabilities), plus the intermediate activations.
        """
        # SAME padding with stride 1 keeps 28x28; each 2x2/stride-2 pool halves it: 28 -> 14 -> 7.
        self.conv1 = tf.nn.relu(
            tf.nn.conv2d(self.x, self.conv1_w, strides=[1, 1, 1, 1], padding='SAME') + self.conv1_b)
        self.pool1 = tf.nn.max_pool(self.conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
        self.conv2 = tf.nn.relu(
            tf.nn.conv2d(self.pool1, self.conv2_w, strides=[1, 1, 1, 1], padding='SAME') + self.conv2_b)
        self.pool2 = tf.nn.max_pool(self.conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

        self.flat = tf.reshape(self.pool2, [-1, 7 * 7 * 32])
        self.y1 = tf.nn.relu(tf.matmul(self.flat, self.w1) + self.b1)
        # Keep the pre-softmax scores: the loss uses them for a numerically stable cross-entropy.
        self.logits = tf.matmul(self.y1, self.w2) + self.b2
        self.y2 = tf.nn.softmax(self.logits)

    def backward(self):
        """Build the loss, optimizer and accuracy ops.

        Requires ``forward()`` to have been called first (reads ``self.logits``
        and ``self.y2``).
        """
        # Softmax cross-entropy on logits rather than MSE on softmax outputs: MSE through a
        # softmax yields vanishing gradients for confidently wrong predictions.
        self.loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(labels=self.y, logits=self.logits))
        self.opt = tf.train.AdamOptimizer().minimize(self.loss)

        # Per example: does the predicted class equal the true class?
        self.prediction_correct = tf.equal(tf.argmax(self.y2, 1), tf.argmax(self.y, 1))
        # Original (misspelled) attribute name, kept as an alias for existing callers.
        self.prediction_corect = self.prediction_correct
        # Booleans -> 1.0 (correct) / 0.0 (wrong).
        self.rst = tf.cast(self.prediction_correct, 'float')
        # Mean over the batch: fraction of correct predictions, in [0, 1].
        self.accuracy = tf.reduce_mean(self.rst)


if __name__ == '__main__':
    num_steps = 1000
    batch_size = 100
    plot_every = 10
    eval_batch_size = 1000  # test-set chunk size, bounds activation memory during evaluation

    net = Net()
    net.forward()
    net.backward()
    init = tf.global_variables_initializer()

    # Create the 1x2 panels and their line artists once and only update their data later.
    # Calling plt.plot on every refresh stacks a new line per call, so memory and drawing
    # cost would grow with every redraw.
    _, (acc_ax, loss_ax) = plt.subplots(1, 2)
    acc_ax.set_title('accuracy rate')
    loss_ax.set_title('loss')
    acc_line, = acc_ax.plot([], [])
    loss_line, = loss_ax.plot([], [])

    steps, accuracies, losses = [], [], []

    with tf.Session() as sess:
        sess.run(init)
        for step in range(num_steps):
            x, y = mnist.train.next_batch(batch_size)
            # Reshape the batch of pixel rows into the [batch, 28, 28, 1] layout net.x expects.
            x = x.reshape([-1, 28, 28, 1])
            loss, acc, _ = sess.run([net.loss, net.accuracy, net.opt],
                                    feed_dict={net.x: x, net.y: y})
            steps.append(step)
            accuracies.append(acc)
            losses.append(loss)

            if step % plot_every == 0:
                acc_line.set_data(steps, accuracies)
                loss_line.set_data(steps, losses)
                for ax in (acc_ax, loss_ax):
                    ax.relim()  # recompute data limits so the new points are in view
                    ax.autoscale_view()
                plt.pause(0.0001)  # give the GUI event loop a chance to redraw
                print(loss, acc)

        # Held-out evaluation on the test split, in chunks; averaging the per-example
        # 0/1 results gives exact accuracy even if the last chunk is smaller.
        test_x = mnist.test.images.reshape([-1, 28, 28, 1])
        test_y = mnist.test.labels
        correct = []
        for start in range(0, len(test_x), eval_batch_size):
            stop = start + eval_batch_size
            correct.extend(sess.run(net.rst, feed_dict={net.x: test_x[start:stop],
                                                         net.y: test_y[start:stop]}))
        print('test accuracy:', sum(correct) / len(correct))