pytroch用自定義的tensor初始化nn.sequential中linear或者conv層的一種簡單方法。
阿新 • • 發佈:2018-11-21
話不多說,上程式碼,上面寫的很清楚。
import torch.nn as nn import torch net= nn.Sequential( nn.Linear(1024, 512), nn.ReLU(inplace=True), nn.Linear(512, 256), nn.ReLU(inplace=True), nn.Linear(256, 6), ) net[4].weight.data=torch.zeros(6,256) net[4].bias.data=torch.ones(6) t=torch.randn(32,1024) print(net(t).size()) cnn=nn.Sequential( nn.Conv2d(2,8,3,1,1), nn.Conv2d(8,19,3,1,1) ) cnn[0].weight.data=torch.randn(8*2*3*3).view(8,2,3,3) cnn[0].bias.data=torch.ones(8) t=torch.randn(32,2,100,100) print(cnn(t).size())
注:
- 注意線性層和卷積層輸入通道和輸出通道的關係,初始化的時候要是轉置的形式。
- 後面生成一個tensor送入網路是為了測試初始化的正確性
- 出錯的話請參考:The expanded size of the tensor (256) must match the existing size (81) at non-singleton dimension1
- 雖然程式碼中寫的是sequential,對於module同樣是可以用的