1. 程式人生 > >Tensorflow(三)- CNN_MNIST

Tensorflow(三)- CNN_MNIST

本文模型為tensorflow官方文件中,構建多層CNN完成MNIST訓練的模型,通過復現模型來講解用到的API以及對CNN進行一個簡單的梳理。

權重初始化

tf.truncated_normal() 截斷正態分佈API
https://www.tensorflow.org/versions/r1.4/api_docs/python/tf/truncated_normal
首先使用正態分佈,加入高斯噪聲,打破對稱性。然後使用截斷正態分佈是為了限制權重大小,不會太大。
所謂截斷正態分佈,就是隻在我截斷的區域裡面取值,官方文件說是將取值限制在離均值2倍標準差的範圍內。一旦超過這個範圍,就重新取值。

def weight_variable(shape):
  initial = tf.truncated_normal(shape, stddev=0.1)
  return tf.Variable(initial)

def bias_variable(shape):
  initial = tf.constant(0.1, shape=shape)
  return tf.Variable(initial)

卷積與池化

卷積
tf.nn.conv2d(input,filter,strides,padding,use_cudnn_on_gpu=None,data_format=None
,name=None)

二維卷積API
https://www.tensorflow.org/api_docs/python/tf/nn/conv2d
首先講述data_format,只有在規範了資料格式的情況下,才能更好的進行討論。
data_format是該API的一個帶有預設引數的輸入,它的意義在於規範卷積的輸入輸出的資料格式,一般不需要定義。預設為”NHWC”,即[batch, height, width, channels]。
input自然是我們的輸入,對於第一層,就是我們的圖片,格式為[mini_batch_size, 圖片高度,圖片寬度,圖片通道數]
filter卷積核,它的格式是[filter_height, filter_width, in_channels, out_channels]。
strides

步長,是配合卷積核在圖片上進行卷積時,移動的步長,它的資料格式對應於input,對於二維卷積來說,strides[0] = strides[3] = 1, 在batch和channel方向上不存在跳步移動的。那麼strides[1]就代表卷積核在圖片高度上的移動,strides[2]就代表卷積核在圖片寬度上的移動。
padding填充,有兩種padding方式,’VALID’表示不進行padding,’SAME’表示進行zero padding。這兩種padding方式要仔細講一下,因為它們決定了我們輸出的feature map的高度和寬度。

  • VALID
    不進行padding,卷積核在圖片上通過strides進行滑動直到卷積核超出圖片,卷積結束,無法卷積到的部分會被丟棄。那麼輸出的影象大小是多少呢,我們來計算一下。(我們只計算寬度方向,高度方向相似)
    filter_width + n × stride in_width < filter_width + ( n + 1 ) × stride in_width filter_width stride < n + 1 in_width filter_width + stride stride
    上述公式中 n 表示卷積核滑動的次數,那麼輸出的寬度就等於 n+1,因為本身卷積核初始所在位置就已經算一個維度了。所以valid padding下,輸出的寬度為:
    為了統一使用ceil,和上面公式對比會發現多加了一個1,這是為了避免ceil取整遇到整數情況,導致取錯方向,因為我們的步長是整數且至少是1,所以加1並不影響其他結果而且還修正了(in_width-filter_width)/strides[2]為整數的情況。本身原理沒有問題。
out_width = ceil((in_width-filter_width+1)/strides[2])
  • SAME
    進行zero padding,即當圖片不夠卷積核進行移動時,對圖片進行補0,直到卷積核整個移出圖片,卷積結束。那麼讓我們來計算一下輸出的特徵圖大小,同樣只計算寬度方向。
    n × stride < in_width ( n + 1 ) × stride in_width stride n + 1 < in_width + stride stride
    同樣 n+1 表示輸出的寬度,
out_width = ceil(in_width/strides[2])

對於SAME來說,還有一點需要注意,那就是它是怎麼補零的,採用左右,上下兩邊平均補零,如果需要補零的個數為奇數,那多出來的那個零補在右邊和下邊。

下面為模型定義的卷積函式,可以看到strides步長均為1,即在影象高度,寬度上都按1進行移動,採用SAME進行padding,結合這個strides和padding方式,我們可以算出,輸出圖片大小等於輸入圖片大小

def conv2d(x, W):
  return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
池化
max_pool(value,ksize,strides,padding,data_format='NHWC',name=None)

最大值池化 API
https://www.tensorflow.org/api_docs/python/tf/nn/max_pool
對於池化來講,池化的data_format, strides以及padding都和上面講的卷積的API完全相同。最大值池化實際上就是對相應的區域選取其中的最大值來代表這整個區域,這個區域大小由ksize決定,這個ksize的資料格式也是同data_format一樣,那麼也就是[batch, height, width, channel]。對於二維卷積來說,ksize[0] = ksize[3] = 1,理由與strides相同,ksize[1]和ksize[2]就共同決定了這個區域大小。同樣地,我們來計算一下進行pool輸出的影象大小(只計算寬度方向)

  • VALID

    k_width + n × stride in_width < k_width + ( n + 1 ) × stride in_width k_width stride < n + 1 in_width k_width + stride stride

  • SAME

    n × stride < in_width ( n + 1 ) × stride in_width stride n + 1 < in_width + stride stride

對於下面模型定義的max_pool來說,ksize[1] = ksize[2] = 2,那麼就相當於對於輸入圖片的 2 × 2 的區域用其中最大的數值來代表,那就相當於寬度上縮小了2倍,高度上縮小了2倍。

def max_pool_2x2(x):
  return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
                        strides=[1, 2, 2, 1], padding='SAME')

兩層卷積

先上程式碼

    W_conv1 = weight_variable([5,5,1,32])
    b_conv1 = bias_variable([32])
    x_image = tf.reshape(x, [-1, 28, 28, 1])
    h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
    h_pool1 = max_pool_2x2(h_conv1)

    W_conv2 = weight_variable([5,5,32,64])
    b_conv2 = bias_variable([64])
    h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
    h_pool2 = max_pool_2x2(h_conv2)

直接講述太麻煩,做了個表格,貼出圖,更加直觀。
1. input中的n,也就是batch,在實際中我們一般會採用mini_batch,這個n的大小就等於mini_batch_size.
2. 關於c,也就是channel,MNIST資料集,黑白影象只有一個通道,所以是1。別的影象可能會有多個渠道,比如RGB就是三通道。
3. 對於引數個數計算,其實很簡單,只需要關注kernel就行了,因為引數實際就是kernel的權重。
4. 對於連線數的計算,關於輸出和引數個數即可,本質上每張輸出的特徵圖上的每個點,都是通過kernel與對應大小的圖片patch相連。
5. 關於池化層的引數個數和連線數個數,博主還沒有參透,不敢妄言,一種說法是沒有引數,最大值pooling就取最大值,平均值pooling就取平均值,還有一種說法是帶有一個bias,即每個特徵圖一個bias。具體哪種是對的,等我參透,再來修改。
這裡寫圖片描述

全連線層帶dropout加輸出層

這個部分就沒什麼好講的了。
tf.nn.dropout() dropout API
https://www.tensorflow.org/api_docs/python/tf/nn/dropout
tensorflow dropout 自帶rescale,也就是我們前面提到過的’inverted dropout’, 這樣就不要在test的時候進行rescale。

W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)    
keep_prob = tf.placeholder("float")
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)

訓練加評估

cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
sess.run(tf.initialize_all_variables())
for i in range(20000):
  batch = mnist.train.next_batch(50)
  if i%100 == 0:
    train_accuracy = accuracy.eval(feed_dict={
        x:batch[0], y_: batch[1], keep_prob: 1.0})
    print "step %d, training accuracy %g"%(i, train_accuracy)
  train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})

print "test accuracy %g"%accuracy.eval(feed_dict={
    x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0})

小tips

博主自己在復現的時候,出現了一些bug,來與大家分享也是希望大家注意。
1. 對輸入的影象進行均值歸一化,那對於MNIST資料集,我們除以255就可以了,如果不除會有什麼後果呢,博主一開始就忘了除,導致輸出層softmax接收到的數值太大,可想而知再求個e指數,數值肯定爆炸,最後cost輸出為nan。
2. 注意data type,比如tf.argmax預設輸出資料型別tf.int64,tf.equal()在進行兩者比較時,資料型別不一樣,自然是不能比較的。
3. 注意資料維度,到底是最終是[batch, data],還是[data, batch]。tf.argmax()預設是按列取最大值下標,如果想要按行,要麼tf.transpose(),要麼tf.argmax( _ , 1)。
4. 注意tf.equal()出來的是bool矩陣,需要利用tf.cast()將其進行型別轉換。
5. 最後注意一下一定要使用mini_batch,而且注意大小,博主剛從多層神經網路出來,上來就試了一個1024,直接把電腦跑來卡住。最後硬關機重啟很心疼啊。

注:本篇部落格程式碼全部來自於tensorflow官方文件,並不是自己的復現程式碼,當然自己的復現程式碼也就稍作了一點修改而已。


小廣告

淘寶choker、耳飾小店 物理禁止
女程式設計師編碼時和編碼之餘 都需要一些美美的choker、耳飾來裝扮自己
男程式設計師更是需要常備一些來送給自己心儀的人
淘寶小店開店不易 希望有緣人多多支援 (O ^ ~ ^ O)
本號是本人 只是發則小廣告 沒有被盜 會持續更新深度學習相關博文和一些翻譯
感謝大家 不要拉黑我 ⊙﹏⊙|||°
這裡寫圖片描述