1. 程式人生 > >pytroch用自定義的tensor初始化nn.sequential中linear或者conv層的一種簡單方法。

pytroch用自定義的tensor初始化nn.sequential中linear或者conv層的一種簡單方法。

話不多說,上程式碼,上面寫的很清楚。

import torch.nn as nn
import torch
net= nn.Sequential(
    nn.Linear(1024, 512),
    nn.ReLU(inplace=True),
    nn.Linear(512, 256),
    nn.ReLU(inplace=True),
    nn.Linear(256, 6),
)
net[4].weight.data=torch.zeros(6,256)
net[4].bias.data=torch.ones(6)
t=torch.randn(32,1024)
print(net(t).size())
cnn=nn.Sequential(
    nn.Conv2d(2,8,3,1,1),
    nn.Conv2d(8,19,3,1,1)
)
cnn[0].weight.data=torch.randn(8*2*3*3).view(8,2,3,3)
cnn[0].bias.data=torch.ones(8)
t=torch.randn(32,2,100,100)
print(cnn(t).size())

注:

  1. 注意線性層和卷積層輸入通道和輸出通道的關係,初始化的時候要是轉置的形式。
  2. 後面生成一個tensor送入網路是為了測試初始化的正確性
  3. 出錯的話請參考:The expanded size of the tensor (256) must match the existing size (81) at non-singleton dimension1
  4. 雖然程式碼中寫的是sequential,對於module同樣是可以用的