Pytorch學習筆記之模型的建立
阿新 • • 發佈:2021-01-25
Pytorch中所有模型都是基於Module這個類,也就是說無論是自定義的模型,還是Pytorch中已有的模型,都是這個類的子類,並重寫了forward方法。Pytorch中建立模型有幾種方法。
繼承Module
這是最直接的方法,自己寫一個模型繼承Module,並重寫forward方法。
from torch.nn import Module
import torch.nn as nn
import torch.nn.functional as F
class LinearMoudule(Module):
def __init__(self) :
super(LinearMoudule,self).__init__()
self.linear_1 = nn.Linear(10,30)
self.linear_2 = nn.Linear(30,5)
def forward(self,x):
x = self.linear_1(x)
x = F.tanh(x)
x = self.linear_2(x)
x = F.sigmoid(x)
return x
使用Sequential
使用Sequential是一種快速構建模型的方法,只需將需要新增的模型放入其建構函式即可。
from torch.nn import Module
import torch.nn as nn
import torch.nn.functional as F
module = nn.Sequential(nn.Linear(10,30),
nn.Tanh(),
nn.Linear(30,5),
nn.Sigmoid())
另外還可以用OrderedDict來對每一層模型進行命名。
from torch.nn import Module
import torch. nn as nn
import torch.nn.functional as F
from collections import OrderedDict
module = nn.Sequential(OrderedDict(
{'linear_1':nn.Linear(10,30),
'tanh':nn.Tanh(),
'linear_2': nn.Linear(30,5),
'sigmod': nn.Sigmoid()}
))
module = nn.Sequential(OrderedDict(
[('linear_1', nn.Linear(10,30)),
('tanh',nn.Tanh()),
('linear_2', nn.Linear(30,5)),
('sigmod', nn.Sigmoid())
))
#兩種寫法都可以
ModuleList和ModuleDict
這兩個類顧名思義,是分別通過List和Dict兩種容器將模組進行包裝來建立新模型的。並且這兩個類可以通過迭代來訪問。
from torch.nn import Module
import torch.nn as nn
import torch.nn.functional as F、
class Module_List(nn.Module):
def __init__(self):
super(Module_List, self).__init__()
self.modules = nn.ModuleList([nn.Linear(10, 30),nn.Tanh(),nn.Linear(30,5),nn.Sigmoid()])
def forward(self, x):
for layer in self.modules:
x = layer(x)
return x
class Module_Dict(nn.Module):
def __init__(self):
super(Module_Dict, self).__init__()
self.modules = nn.ModuleDict({'linear_1' : nn.Linear(10,30),
'tanh':nn.Tanh(),
'linear_2': nn.Linear(30,5),
'sigmod': nn.Sigmoid()})
def forward(self, x):
for layer in self.modules:
x = layer(x)
return x
混合使用
前面幾種建立模型的方法可以混合使用,來建立更為複雜的模型。
class Bottle(nn.Module):
def __init__(self, in_channel, out_channel, kernel_size):
super(Bottle, self).__init__()
self.conv = nn.Conv2d(in_channel, out_channel, kernel_size)
self.relu = nn.ReLU()
self.pool = nn.MaxPool2d(2, 2)
def forward(self,x):
x = self.conv(x)
x = self.relu(x)
x = self.pool(x)
return x
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
self.bottle_1 = Bottle(3,6,5)
self.bottle_2 = Bottle(6,16,5)
self.fc = nn.Sequential(nn.Linear(16 * 5 * 5, 120),
nn.ReLU(),
nn.Linear(120, 84),
nn.ReLU())
self.last_fc = nn.Linear(84, 10)
def forward(self,x):
x = self.bottle_1(x)
x = self.bottle_2(x)
x = x.view(-1, 16 * 5 * 5)
x = self.fc(x)
x = self.last_fc(x)
return x