PyTorch的nn.Linear()詳解
阿新 • • 發佈:2020-07-23
1. nn.Linear()
-
nn.Linear():用於設定網路中的全連線層,需要注意的是全連線層的輸入與輸出都是二維張量
-
一般形狀為[batch_size, size],不同於卷積層要求輸入輸出是四維張量。其用法與形參說明如下:
-
in_features
指的是輸入的二維張量的大小,即輸入的[batch_size, size]
中的size。 -
out_features
指的是輸出的二維張量的大小,即輸出的二維張量的形狀為[batch_size,output_size]
,當然,它也代表了該全連線層的神經元個數。 -
從輸入輸出的張量的shape角度來理解,相當於一個輸入為
[batch_size, in_features]
[batch_size, out_features]
的輸出張量。
用法示例:
import torch as t from torch import nn from torch.nn import functional as F # 假定輸入的影象形狀為[3,64,64] x = t.randn(10, 3, 64, 64) # 10張 3個channel 大小為64x64的圖片 x = nn.Conv2d(3, 64, kernel_size=3, stride=3, padding=0)(x) print(x.shape) # 之前的特徵圖尺寸為多少,只要設定為(1,1),那麼最終特徵圖大小都為(1,1) # x = F.adaptive_avg_pool2d(x, [1,1]) # [b, 64, h, w] => [b, 64, 1, 1] # print(x.shape) # 將四維張量轉換為二維張量之後,才能作為全連線層的輸入 x = x.view(x.size(0), -1) print(x.shape) # in_features由輸入張量的形狀決定,out_features則決定了輸出張量的形狀 connected_layer = nn.Linear(in_features = 64*21*21, out_features = 10) # 呼叫全連線層 output = connected_layer(x) print(output.shape)
torch.Size([10, 64, 21, 21])
torch.Size([10, 28224])
torch.Size([10, 10])