【小白學PyTorch】4 構建模型三要素與權重初始化
阿新 • • 發佈:2020-09-03
文章目錄:
[TOC]
## 1 模型三要素
三要素其實很簡單
1. 必須要繼承nn.Module這個類,要讓PyTorch知道這個類是一個Module
2. 在\_\_init\_\_(self)中設定好需要的元件,比如conv,pooling,Linear,BatchNorm等等
3. 最後在forward(self,x)中用定義好的元件進行組裝,就像搭積木,把網路結構搭建出來,這樣一個模型就定義好了
我們來看一個例子:
先看\_\_init\_\_(self)函式
```python
def __init__(self):
super(Net,self).__init__()
self.conv1 = nn.Conv2d(3,6,5)
self.pool1 = nn.MaxPool2d(2,2)
self.conv2 = nn.Conv2d(6,16,5)
self.pool2 = nn.MaxPool2d(2,2)
self.fc1 = nn.Linear(16*5*5,120)
self.fc2 = nn.Linear(120,84)
self.fc3 = nn.Linear(84,10)
```
第一行是初始化,往後定義了一系列元件。```nn.Conv2d```就是一般圖片處理的卷積模組,然後池化層,全連線層等等。
定義完這些定義forward函式
```python
def forward(self,x):
x = self.pool1(F.relu(self.conv1(x)))
x = self.pool2(F.relu(self.conv2(x)))
x = x.view(-1,16*5*5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
```
x為模型的輸入,第一行表示x經過conv1,然後經過啟用函式relu,然後經過pool1操作
第三行表示對x進行reshape,為後面的全連線層做準備
至此,對一個模型的定義完畢,如何使用呢?
例如:
```python
net = Net()
outputs = net(inputs)
```
其實```net(inputs)```,就是類似於使用了```net.forward(inputs)```這個函式。
## 2 引數初始化
簡單地說就是**設定什麼層用什麼初始方法,初始化的方法會在torch.nn.init中**
話不多說,看一個案例:
```python
# 定義權值初始化
def initialize_weights(self):
for m in self.modules():
if isinstance(m,nn.Conv2d):
torch.nn.init.xavier_normal_(m.weight.data)
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m,nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m,nn.Linear):
torch.nn.init.normal_(m.weight.data,0,0.01)
# m.weight.data.normal_(0,0.01)
m.bias.data.zero_()
```
這段程式碼的基本流程就是,先從self.modules()中遍歷每一層,然後判斷更曾屬於什麼型別,是否是Conv2d,是否是BatchNorm2d,是否是Linear的,然後根據不同型別的層,設定不同的權值初始化方法,例如Xavier,kaiming,normal_等等。**kaiming也是MSRA初始化,是何愷明大佬在微軟亞洲研究院的時候,因此得名。**
上面程式碼中用到了```self.modules()```,這個是什麼東西呢?
```python
# self.modules的原始碼
def modules(self):
for name,module in self.named_modules():
yield module
```
功能就是:能依次返回模型中的各層,```yield```是讓一個函式可以像迭代器一樣可以用for迴圈不斷從裡面遍歷(可能說的不太明確)。
## 3 完整執行程式碼
我們用下面的例子來更深入的理解```self.modules()```,同時也把上面的內容都串起來(下面的程式碼塊可以執行):
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset,DataLoader
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool1 = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.pool2 = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool1(F.relu(self.conv1(x)))
x = self.pool2(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
def initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
torch.nn.init.xavier_normal_(m.weight.data)
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
torch.nn.init.normal_(m.weight.data, 0, 0.01)
# m.weight.data.normal_(0,0.01)
m.bias.data.zero_()
net = Net()
net.initialize_weights()
print(net.modules())
for m in net.modules():
print(m)
```
執行結果:
```python
# 這個是print(net.modules())