1. 程式人生 > 其它 >卷積 - 用pytorch計算

卷積 - 用pytorch計算

對於一張圖片,每個卷積核的通道數都和圖片通道數一樣,用n個卷積核進行卷積得到的結果就是一個n通道的特徵圖。

卷積的步長(stride)和padding決定了產生的特徵圖的size

對於一張7*7*3的圖片(長7,寬7,3通道),使用兩個3*3*3的卷積核(長3,寬3,3通道)進行卷積,如下圖:

最左邊一列是原始圖片的3個通道的資料,中間兩列紅色的是2個卷積核。最右邊一列是卷積得到的feature map,2通道。
卷積核的3通道分別和影象的3個通道進行element-wise的相乘,結果再相加。
比如,上圖綠色框中的3的計算過程就是每行的紅框和對應的藍框做卷積得到3個數,分別是1,1,0. 然後加上下面的Bias,結果就是3.

pytorch算上圖卷積

import torch
import torch.nn as nn

# 輸入是4維(1 * 3 * 5 * 5)
# 1張圖,3通道,高5,寬5
x = [[
    [
        [0,1,1,0,2],
        [0,2,2,1,1],
        [0,2,1,1,2],
        [1,0,2,1,2],
        [0,1,1,1,1]
    ],
    [
        [2,2,1,0,0],
        [2,1,0,0,1],
        [2,2,1,2,2],
        [1,2,0,1,2],
        [1,2,0,2,1]
    ],
    [
        [0,0,1,0,2],
        [1,1,1,1,1],
        [2,0,2,1,1],
        [1,1,1,0,0],
        [0,0,0,1,0]
    ]
]]

w0  =  [[[-1,  0,  1],
         [-1,  0,  1],
         [ 0,  1,  0]],

        [[ 1, -1,  1],
         [ 1,  1,  0],
         [-1,  0, -1]],

        [[ 0,  1,  0],
         [-1,  0,  1],
         [-1,  0,  0]]]

w1  =  [[[-1,  1,  0],
         [ 1,  0, -1],
         [-1,  0,  1]],

        [[ 0,  0,  1],
         [ 0,  1,  1],
         [ 1,  1, -1]],

        [[ 0,  1,  0],
         [-1, -1,  1],
         [ 1, -1, -1]]]
bias = [1,0]
# 轉換成float型別的Tensor,後面會自動算梯度,要用float。
x = torch.tensor(x).float()
w = nn.Conv2d(3,2,3,padding=1,stride=2) # 建立卷積函式
w.weight.data = torch.tensor([w0,w1]).float() # 使用自定義卷積核,不然pytorch預設會自己生成
w.bias.data = torch.tensor(bias).float()
output = w(x)
'''
不想轉成float的話,就關閉自動計算梯度,也省視訊記憶體
x = torch.tensor(x)
w = nn.Conv2d(3,2,3,padding=1,stride=2)
w.weight.data = torch.tensor([w0,w1])
w.bias.data = torch.tensor(bias)
with torch.no_grad():
    output = w(x)
'''
print(output)

輸出如下:

tensor([[[[ 3.,  3.,  1.],
          [ 6.,  3.,  3.],
          [ 5.,  9.,  0.]],

         [[ 4.,  0., -2.],
          [-1.,  6.,  4.],
          [ 6.,  7.,  2.]]]], grad_fn=<ThnnConv2DBackward>)

關閉梯度計算後的輸出:

tensor([[[[ 3,  3,  1],
          [ 6,  3,  3],
          [ 5,  9,  0]],

         [[ 4,  0, -2],
          [-1,  6,  4],
          [ 6,  7,  2]]]])