pytorch中Linear類中weight的形狀問題原始碼探討
阿新 • • 發佈:2018-11-01
import torch
from torch import nn
m = nn.Linear(20, 30)
input = torch.randn(128, 20)
output = m(input)
print(output.size())
print(m.weight.shape)
來看一下輸出:
out:
torch.Size([128, 30])
torch.Size([30, 20])
發現weight的形狀是[30,20]而非[20, 30]?
所以具體看一下原始碼的實現方式:
- Linear類的原始碼網址:https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html
- functional模組的原始碼網址:
https://pytorch.org/docs/stable/_modules/torch/nn/functional.html
- 在Linear類中的
__init__
函式中,weight形狀為[out_features, in_features]
- 在
forward
函式中呼叫F.linear
函式,實現單層線性神經網路層的計算
- 在F.linear函式中,使用的是
weight.t()
,也就是將weight轉置,再傳入matmul計算。
通過以上三步,pytorch就完成weight形狀的維護。簡單的說就是,在定義時使用的是[out_features, in_features],而在單層線性神經網路計算時使用的是weight的轉置矩陣。