使用pytorch實現Inception模組
阿新 • • 發佈:2019-02-12
在pytorch中沒有找到Inception模組,自己寫了一個,以供呼叫。
Inception模組的順序為:
1. 輸入 -> 1*1卷積 -> BatchNorm -> ReLU -> 1*5卷積 -> BatchNorm -> ReLU
2. 輸入 -> 1*1卷積 -> BatchNorm -> ReLU -> 1*3卷積 -> BatchNorm -> ReLU
3. 輸入 -> 池化 -> 1*1卷積 -> BatchNorm -> ReLU
4. 輸入 -> 1*1卷積 -> BatchNorm -> ReLU
其中,1和2步驟可以重複多次。最後將所有結果串接起來。
pytorch中實現如下,應用例子見我的下一篇文章:openface(三):卷積網路。
import torch.nn as nn
class Inception(nn.Module):
def __init__(self, inputSize, kernelSize, kernelStride, outputSize, reduceSize, pool):
# inputSize:輸入尺寸
# kernelSize:第1步驟和第2步驟中第二個卷積核的尺寸,是一個列表
# kernelStride:同上
# outputSize:同上
# reduceSize:1*1卷積中的輸出尺寸,是一個列表
# pool: 是一個池化層
super(Inception, self).__init__()
self.layers = {}
poolFlag = True
fname = 0
for p in kernelSize, kernelStride, outputSize, reduceSize:
if len(p) == 4:
(_kernel, _stride, _output, _reduce) = p
self.layers[str(fname)] = nn.Sequential(
# Convolution 1*1
nn.Conv2d(inputSize, _reduce, 1),
nn.BatchNorm2d(_reduce),
nn.ReLU(),
# Convolution kernel*kernel
nn.Conv2d(_reduce, _output, _kernel, _stride),
nn.BatchNorm2d(_output),
nn.ReLU())
else:
if poolFlag:
assert len(p) == 1
self.layers[str(fname)] = nn.Sequential(
# pool
pool, #這裡的輸出尺寸需要考慮一下
nn.Conv2d(inputSize, p, 1),
nn.BatchNorm2d(p),
nn.ReLU())
poolFlag = False
else:
assert len(p) == 1
self.layers[str(fname)] = nn.Sequential(
# Convolution 1*1
nn.Conv2d(inputSize, p, 1),
nn.BatchNorm2d(p),
nn.ReLU())
fname += 1
if poolFlag:
self.layers[str(fname)] = nn.Sequential(pool)
poolFlag = False
def forward(self, x):
for key, layer in self.layers.items:
if key == str(0):
out = layer(x)
else:
out = torch.cat((out, layer(x)), 1) #因為有Batch,所以是在第1維方向串接。
return out