全連線層(線性層)-Linear
阿新 • • 發佈:2021-01-27
原理:
#%% import torch import torch.nn as nn #%% 輸入資料 in_features = 3 out_features = 2 input = torch.arange(in_features,dtype=torch.float32).view([1,-1]) #%% 線性層輸出(pytorch) connect1 = nn.Linear(in_features=in_features,out_features=out_features,bias=True) out = connect1(input) print('pytorch輸出:',out) #%% 線性層輸出(公式) #權重 print('weight:',connect1.weight) print('bias:',connect1.bias) #y = xAt+b print('公式輸出:',torch.mm(input,connect1.weight.t())+connect1.bias)
注意:pytorch Linear層引數.weight()為A,需要轉置後和x相乘,再和b相加
pytorch輸出: tensor([[-0.3843, 0.9226]], grad_fn=<AddmmBackward>) weight: tensor([[ 0.4394, -0.3183, 0.1999], [ 0.4067, 0.5720, 0.0915]], requires_grad=True) bias: tensor([-0.4659, 0.1677], requires_grad=True) 公式輸出: tensor([[-0.3843, 0.9226]], grad_fn=<AddBackward0>)