1. 程式人生 > 實用技巧 >PyTorch的nn.Linear()詳解

PyTorch的nn.Linear()詳解

1. nn.Linear()

  • nn.Linear():用於設定網路中的全連線層,需要注意的是全連線層的輸入與輸出都是二維張量

  • 一般形狀為[batch_size, size],不同於卷積層要求輸入輸出是四維張量。其用法與形參說明如下:

  • in_features指的是輸入的二維張量的大小,即輸入的[batch_size, size]中的size。

  • out_features指的是輸出的二維張量的大小,即輸出的二維張量的形狀為[batch_size,output_size],當然,它也代表了該全連線層的神經元個數。

  • 從輸入輸出的張量的shape角度來理解,相當於一個輸入為[batch_size, in_features]

    的張量變換成了[batch_size, out_features]的輸出張量。

用法示例:

import torch as t
from torch import nn
from torch.nn import functional as F

# 假定輸入的影象形狀為[3,64,64]
x = t.randn(10, 3, 64, 64)      # 10張 3個channel 大小為64x64的圖片

x = nn.Conv2d(3, 64, kernel_size=3, stride=3, padding=0)(x)
print(x.shape)


# 之前的特徵圖尺寸為多少,只要設定為(1,1),那麼最終特徵圖大小都為(1,1) 
# x = F.adaptive_avg_pool2d(x, [1,1])    # [b, 64, h, w] => [b, 64, 1, 1]
# print(x.shape)

# 將四維張量轉換為二維張量之後,才能作為全連線層的輸入
x = x.view(x.size(0), -1)
print(x.shape)

# in_features由輸入張量的形狀決定,out_features則決定了輸出張量的形狀 
connected_layer = nn.Linear(in_features = 64*21*21, out_features = 10)

# 呼叫全連線層
output = connected_layer(x) 
print(output.shape)
torch.Size([10, 64, 21, 21])
torch.Size([10, 28224])
torch.Size([10, 10])