如何列印Pytorch中Sequential包裝的中間層的大小
阿新 • • 發佈:2021-04-23
題目引入
為了方便,我們用Sequential定義一個卷積神經網路中,例如
self.fc = nn.Sequential(
nn.Upsample(scale_factor=2,mode='nearest'),
nn.Conv2d(512,512,3,1,1,bias=False),
nn.BatchNorm2d(512),
nn.ReLU(True),
…………
那麼我們如何檢視諸如nn.Conv2d(512,512,3,1,1,bias=False)
這一層輸出的大小呢?
方法就是定義一個列印函式並定義一個符合規格的假資料,然後迴圈輸出。
程式碼
def printInfo(self):
x = torch.rand([3,512,1,1])
for name, module in self.fc.named_children():
x = module(x)
if isinstance(module,nn.Upsample):
print("Upsample({}) : {}".format(name,x.shape))
elif isinstance(module, nn.Conv2d):
print("Conv2d({}) : {}".format(name,x.shape))
這樣,只需要使用model.printInfo()
就可以輸出我們所關心的某幾層的特徵圖規格。
如下圖,第一行諸如Upsample
後面的數字就是nn.Sequential
對應的層號