CNN實現MNIST資料集數字識別
阿新 • 發佈:2018-12-14
# -*- coding: utf-8 -*-
# @Time     : 2018/12/14 13:07
# @Author   : WenZhao
# @Email    : (address lost during web extraction)
# @File     : mnistCnn-1.py
# @Software : PyCharm
"""Recognize MNIST handwritten digits with a small convolutional network.

Building blocks of the network (TensorFlow 1.x graph API):
    1. Convolution layer:   conv2d
    2. Non-linearity:       tf.nn.relu / sigmoid / tanh (activation functions)
    3. Pooling layer:       tf.nn.max_pool / tf.nn.avg_pool
    4. Fully connected:     w * x + b

Two conv(5x5, ReLU) + max-pool(2x2, stride 2) stages reduce each 28x28
image to 7x7x64 features, followed by a 1024-unit ReLU layer with dropout
and a 10-way output layer.
"""
import tensorflow as tf
# Downloads (on first run) and loads the dataset.
from tensorflow.examples.tutorials.mnist import input_data

TRAIN_STEPS = 20000
BATCH_SIZE = 50
LEARNING_RATE = 0.0001
LOG_EVERY = 100
# Test images are evaluated in chunks so the whole test set never has to
# be materialized as a single activation tensor (GPU memory).
EVAL_BATCH_SIZE = 1000

mnist = input_data.read_data_sets("./data/MNIST_data/", one_hot=True)

x = tf.placeholder("float", shape=[None, 784])
y_ = tf.placeholder("float", shape=[None, 10])
# Flat 784-value rows -> [batch, 28, 28, 1] images for the conv layers.
x_image = tf.reshape(x, [-1, 28, 28, 1])

# tf.contrib.layers.convolution2d performs both the convolution and the
# activation in one call.
conv2d_1 = tf.contrib.layers.convolution2d(
    x_image,
    num_outputs=32,
    kernel_size=(5, 5),
    activation_fn=tf.nn.relu,
    stride=(1, 1),
    padding='SAME',
    trainable=True,
)
# Pooling layer: 2x2 window, stride 2 -> spatial size halves (28 -> 14).
pool_1 = tf.nn.max_pool(conv2d_1, ksize=[1, 2, 2, 1],
                        strides=[1, 2, 2, 1], padding='SAME')

conv2d_2 = tf.contrib.layers.convolution2d(
    pool_1,
    num_outputs=64,
    kernel_size=(5, 5),
    activation_fn=tf.nn.relu,
    stride=(1, 1),
    padding='SAME',
    trainable=True,
)
# 14 -> 7 after the second pooling stage.
pool_2 = tf.nn.max_pool(conv2d_2, ksize=[1, 2, 2, 1],
                        strides=[1, 2, 2, 1], padding='SAME')

# Flatten the 7x7x64 feature maps for the fully connected layers.
pool2_flat = tf.reshape(pool_2, [-1, 7 * 7 * 64])

# Fully connected layer.
fc_1 = tf.contrib.layers.fully_connected(pool2_flat, 1024,
                                         activation_fn=tf.nn.relu)

# Dropout randomly zeroes units during training to reduce overfitting;
# feed keep_prob=1.0 whenever evaluating so it is disabled.
keep_prob = tf.placeholder("float")
fc1_drop = tf.nn.dropout(fc_1, keep_prob)

# The output layer produces raw logits; fc_2 holds the softmax probabilities.
logits = tf.contrib.layers.fully_connected(fc1_drop, 10, activation_fn=None)
fc_2 = tf.nn.softmax(logits)

# Fused softmax + cross-entropy is numerically stable, unlike
# -sum(y_ * log(softmax)), which becomes NaN once a probability underflows
# to 0. Summed (not averaged) over the batch to keep the original loss
# scale, and hence the original effective learning rate.
loss = tf.reduce_sum(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits))
train_step = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(loss)

# Accuracy ops are built once, before the session runs the graph.
correct_prediction = tf.equal(tf.argmax(fc_2, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
num_correct = tf.reduce_sum(tf.cast(correct_prediction, "float"))

# The context manager releases the session's resources when done.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(TRAIN_STEPS):
        batch = mnist.train.next_batch(BATCH_SIZE)
        sess.run(train_step,
                 feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
        if i % LOG_EVERY == 0:
            # Log the loss with dropout disabled so the reported value is
            # not perturbed by random unit dropping.
            print(sess.run(loss, feed_dict={x: batch[0], y_: batch[1],
                                            keep_prob: 1.0}))

    # Test-set accuracy, accumulated chunk by chunk as an exact count of
    # correct predictions.
    num_test = mnist.test.num_examples
    total_correct = 0.0
    for start in range(0, num_test, EVAL_BATCH_SIZE):
        stop = start + EVAL_BATCH_SIZE
        total_correct += sess.run(num_correct, feed_dict={
            x: mnist.test.images[start:stop],
            y_: mnist.test.labels[start:stop],
            keep_prob: 1.0,
        })
    acc = total_correct / num_test
    print(acc)