pytorch系列 ----暫時就叫5的番外吧: nn.Modlue及nn.Linear 原始碼理解
阿新 • • 發佈:2018-12-14
先看一個列子:
import torch
from torch import nn
m = nn.Linear(20, 30)
input = torch.randn(128, 20)
output = m(input)
output.size()
out:
torch.Size([128, 30])
剛開始看這份程式碼是有點迷惑的,m是類物件,而直接像函式一樣呼叫m,m(input)
重點:
- nn.Module 是所有神經網路單元(neural network modules)的基類
- pytorch在nn.Module中,實現了
__call__
方法,而在__call__
方法中呼叫了forward函式。
經過以上兩點。上述程式碼就不難理解。
返回的是:
的線性函式
此時再看一下這一份程式碼:
import torch
from torch import nn
m = nn.Linear(20, 30)
input = torch.randn(128, 20)
output = m(input)
output.size()
首先建立類物件m,然後通過m(input)
實際上呼叫__call__(input)
,然後__call__(input)
呼叫
forward()
函式,最後返回計算結果為:
所以自己建立多層神經網路模組時,只需要在實現__init__
和forward
即可.
接下來看一個簡單的三層神經網路的例子:
# define three layers
class simpleNet(nn.Module):
def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
super().__init__()
self.layer1 = nn.Linear(in_dim, n_hidden_1)
self.layer2 = nn.Linear(n_hidden_1, n_hidden_2)
self.layer3 = nn.Linear(n_hidden_2, out_dim)
def forward(self, x):
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
return x
以下為各層神經元個數:
輸入: in_dim
第一層: n_hidden_1
第二層:n_hidden_2
第三層(輸出層):out_dim