1. 程式人生 > 其它 >全連線層(線性層)-Linear

全連線層(線性層)-Linear

技術標籤:影象處理深度學習

原理:

#%%

import torch
import torch.nn as nn

#%% 輸入資料
in_features = 3
out_features = 2
input = torch.arange(in_features,dtype=torch.float32).view([1,-1])

#%% 線性層輸出(pytorch)
connect1 = nn.Linear(in_features=in_features,out_features=out_features,bias=True)
out = connect1(input)
print('pytorch輸出:',out)

#%% 線性層輸出(公式)
#權重
print('weight:',connect1.weight)
print('bias:',connect1.bias)

#y = xAt+b
print('公式輸出:',torch.mm(input,connect1.weight.t())+connect1.bias)

注意:pytorch Linear層引數.weight()為A,需要轉置後和x相乘,再和b相加

pytorch輸出: tensor([[-0.3843,  0.9226]], grad_fn=<AddmmBackward>)

weight: 
tensor([[ 0.4394, -0.3183,  0.1999],
        [ 0.4067,  0.5720,  0.0915]], requires_grad=True)
bias: 
tensor([-0.4659,  0.1677], requires_grad=True)

公式輸出: tensor([[-0.3843,  0.9226]], grad_fn=<AddBackward0>)