PyTorch預訓練的實現
前言
最近使用PyTorch感覺妙不可言,有種當初使用Keras的快感,而且速度還不慢。各種設計直接簡潔,方便研究,比tensorflow的臃腫好多了。今天讓我們來談談PyTorch的預訓練,主要是自己寫程式碼的經驗以及論壇PyTorch Forums上的一些回答的總結整理。
直接載入預訓練模型
如果我們使用的模型和原模型完全一樣,那麼我們可以直接載入別人訓練好的模型:
my_resnet = MyResNet(*args,**kwargs) my_resnet.load_state_dict(torch.load("my_resnet.pth"))
當然這樣的載入方法是基於PyTorch推薦的儲存模型的方法:
torch.save(my_resnet.state_dict(),"my_resnet.pth")
還有第二種載入方法:
my_resnet = torch.load("my_resnet.pth")
載入部分預訓練模型
其實大多數時候我們需要根據我們的任務調節我們的模型,所以很難保證模型和公開的模型完全一樣,但是預訓練模型的引數確實有助於提高訓練的準確率,為了結合二者的優點,就需要我們載入部分預訓練模型。
pretrained_dict = model_zoo.load_url(model_urls['resnet152']) model_dict = model.state_dict() # 將pretrained_dict裡不屬於model_dict的鍵剔除掉 pretrained_dict = {k: v for k,v in pretrained_dict.items() if k in model_dict} # 更新現有的model_dict model_dict.update(pretrained_dict) # 載入我們真正需要的state_dict model.load_state_dict(model_dict)
因為需要剔除原模型中不匹配的鍵,也就是層的名字,所以我們的新模型改變了的層需要和原模型對應層的名字不一樣,比如:resnet最後一層的名字是fc(PyTorch中),那麼我們修改過的resnet的最後一層就不能取這個名字,可以叫fc_
微改基礎模型預訓練
對於改動比較大的模型,我們可能需要自己實現一下再載入別人的預訓練引數。但是,對於一些基本模型PyTorch中已經有了,而且我只想進行一些小的改動那麼怎麼辦呢?難道我又去實現一遍嗎?當然不是。
我們首先看看怎麼進行微改模型。
微改基礎模型
PyTorch中的torchvision裡已經有很多常用的模型了,可以直接呼叫:
- AlexNet
- VGG
- ResNet
- SqueezeNet
- DenseNet
import torchvision.models as models resnet18 = models.resnet18() alexnet = models.alexnet() squeezenet = models.squeezenet1_0() densenet = models.densenet_161()
但是對於我們的任務而言有些層並不是直接能用,需要我們微微改一下,比如,resnet最後的全連線層是分1000類,而我們只有21類;又比如,resnet第一層卷積接收的通道是3, 我們可能輸入圖片的通道是4,那麼可以通過以下方法修改:
resnet.conv1 = nn.Conv2d(4,64,kernel_size=7,stride=2,padding=3,bias=False) resnet.fc = nn.Linear(2048,21)
簡單預訓練
模型已經改完了,接下來我們就進行簡單預訓練吧。
我們先從torchvision中呼叫基本模型,載入預訓練模型,然後,重點來了,將其中的層直接替換為我們需要的層即可:
resnet = torchvision.models.resnet152(pretrained=True) # 原本為1000類,改為10類 resnet.fc = torch.nn.Linear(2048,10)
其中使用了pretrained引數,會直接載入預訓練模型,內部實現和前文提到的載入預訓練的方法一樣。因為是先載入的預訓練引數,相當於模型中已經有引數了,所以替換掉最後一層即可。OK!
以上就是本文的全部內容,希望對大家的學習有所幫助,也希望大家多多支援我們。