1. 程式人生 > >使用pytorch實現Inception模組

使用pytorch實現Inception模組

在pytorch中沒有找到Inception模組,自己寫了一個,以供呼叫。
Inception模組的順序為:
1. 輸入 -> 1*1卷積 -> BatchNorm -> ReLU -> 1*5卷積 -> BatchNorm -> ReLU
2. 輸入 -> 1*1卷積 -> BatchNorm -> ReLU -> 1*3卷積 -> BatchNorm -> ReLU
3. 輸入 -> 池化 -> 1*1卷積 -> BatchNorm -> ReLU
4. 輸入 -> 1*1卷積 -> BatchNorm -> ReLU
其中,1和2步驟可以重複多次。最後將所有結果串接起來。
pytorch中實現如下,應用例子見我的下一篇文章:openface(三):卷積網路。

import torch.nn as nn
class Inception(nn.Module):
    def __init__(self, inputSize, kernelSize, kernelStride, outputSize, reduceSize, pool):
         # inputSize:輸入尺寸
         # kernelSize:第1步驟和第2步驟中第二個卷積核的尺寸,是一個列表
         # kernelStride:同上
         # outputSize:同上
         # reduceSize:1*1卷積中的輸出尺寸,是一個列表
# pool: 是一個池化層 super(Inception, self).__init__() self.layers = {} poolFlag = True fname = 0 for p in kernelSize, kernelStride, outputSize, reduceSize: if len(p) == 4: (_kernel, _stride, _output, _reduce) = p self.layers[str(fname)] = nn.Sequential( # Convolution 1*1
nn.Conv2d(inputSize, _reduce, 1), nn.BatchNorm2d(_reduce), nn.ReLU(), # Convolution kernel*kernel nn.Conv2d(_reduce, _output, _kernel, _stride), nn.BatchNorm2d(_output), nn.ReLU()) else: if poolFlag: assert len(p) == 1 self.layers[str(fname)] = nn.Sequential( # pool pool, #這裡的輸出尺寸需要考慮一下 nn.Conv2d(inputSize, p, 1), nn.BatchNorm2d(p), nn.ReLU()) poolFlag = False else: assert len(p) == 1 self.layers[str(fname)] = nn.Sequential( # Convolution 1*1 nn.Conv2d(inputSize, p, 1), nn.BatchNorm2d(p), nn.ReLU()) fname += 1 if poolFlag: self.layers[str(fname)] = nn.Sequential(pool) poolFlag = False def forward(self, x): for key, layer in self.layers.items: if key == str(0): out = layer(x) else: out = torch.cat((out, layer(x)), 1) #因為有Batch,所以是在第1維方向串接。 return out