1. 程式人生 > >tf.nn.max_pool引數含義和用法

tf.nn.max_pool引數含義和用法

max pooling是CNN當中的最大值池化操作,其實用法和卷積很類似

有些地方可以從卷積去參考【TensorFlow】tf.nn.conv2d是怎樣實現卷積的?

tf.nn.max_pool(value, ksize, strides, padding, name=None)
引數是四個,和卷積很類似:

第一個引數value:需要池化的輸入,一般池化層接在卷積層後面,所以輸入通常是feature map,依然是[batch, height, width, channels]這樣的shape

第二個引數ksize:池化視窗的大小,取一個四維向量,一般是[1, height, width, 1],因為我們不想在batch和channels上做池化,所以這兩個維度設為了1

第三個引數strides:和卷積類似,視窗在每一個維度上滑動的步長,一般也是[1, stride,stride, 1]

第四個引數padding:和卷積類似,可以取’VALID’ 或者’SAME’。padding = ‘SAME’ 時,輸出並不一定和原圖size一致,但會保證覆蓋原圖所有畫素,不會捨棄邊上的莫些元素;padding = ‘VALID’ 時,輸出的size總比原圖的size小,有時不會覆蓋原圖所有元素(既,可能捨棄邊上的某些元素).

返回一個Tensor,型別不變,shape仍然是[batch, height, width, channels]這種形式

示例原始碼:

假設有這樣一張圖,雙通道

第一個通道:
在這裡插入圖片描述
第二個通道:
在這裡插入圖片描述
用程式去做最大值池化:

     import tensorflow as tf  
      
    a=tf.constant([  
            [[1.0,2.0,3.0,4.0],  
            [5.0,6.0,7.0,8.0],  
            [8.0,7.0,6.0,5.0],  
            [4.0,3.0,2.0,1.0]],  
            [[4.0,3.0,2.0,1.0],  
             [8.0,7.0,6.0,5.0],  
             [1.0,2.0,3.0,4.0],  
             [5.0,6.0,7.0,8.0]]  
        ])  
      
    a=tf.reshape(a,[1,4,4,2])  
      
    pooling=tf.nn.max_pool(a,[1,2,2,1],[1,1,1,1],padding='VALID')  
    with tf.Session() as sess:  
        print("image:")  
        image=sess.run(a)  
        print (image)  
        print("reslut:")  
        result=sess.run(pooling)  
        print (result)  

這裡步長為1,視窗大小2×2,輸出結果:

     image:  
    [[[[ 1.  2.]  
       [ 3.  4.]  
       [ 5.  6.]  
       [ 7.  8.]]  
      
      [[ 8.  7.]  
       [ 6.  5.]  
       [ 4.  3.]  
       [ 2.  1.]]  
      
      [[ 4.  3.]  
       [ 2.  1.]  
       [ 8.  7.]  
       [ 6.  5.]]  
      
      [[ 1.  2.]  
       [ 3.  4.]  
       [ 5.  6.]  
       [ 7.  8.]]]]  
    reslut:  
    [[[[ 8.  7.]  
       [ 6.  6.]  
       [ 7.  8.]]  
      
      [[ 8.  7.]  
       [ 8.  7.]  
       [ 8.  7.]]  
      
      [[ 4.  4.]  
       [ 8.  7.]  
       [ 8.  8.]]]]  

池化後的圖就是:
在這裡插入圖片描述

在這裡插入圖片描述

證明了程式的結果是正確的。

我們還可以改變步長

     pooling=tf.nn.max_pool(a,[1,2,2,1],[1,2,2,1],padding='VALID')  

最後的result就變成:

     reslut:  
    [[[[ 8.  7.]  
       [ 7.  8.]]  
      
      [[ 4.  4.]  
       [ 8.  8.]]]]