1. 程式人生 > >Pytorch載入部分預訓練模型的引數

Pytorch載入部分預訓練模型的引數

前言

自從從深度學習框架caffe轉到Pytorch之後,感覺Pytorch的優點妙不可言,各種設計簡潔,方便研究網路結構修改,容易上手,比TensorFlow的臃腫好多了。對於深度學習的初學者,Pytorch值得推薦。今天主要主要談談Pytorch是如何載入預訓練模型的引數以及程式碼的實現過程。

直接載入預選臉模型

如果我們使用的模型和預訓練模型完全一樣,那麼我們就可以直接載入別人的模型,還有一種情況,我們在訓練自己模型的過程中,突然中斷了,但只要我們儲存了之前的模型的引數也可以使用下面的程式碼直接載入我們儲存的模型繼續訓練,不用從頭開始。

model=DPN(*args, **kwargs)
model.load_state_dict(torch.load("DPN.pth"))

這樣的載入方式是基於Pytorch使用的模型儲存方法:

torch.save(DPN.state_dict(), "DPN.pth")
載入部分預訓練模型引數

其實大多數時候我們根據自己的任物所提出的模型是在一些公開模型的基礎上改變而來,其中公開模型的引數我們沒有必要在從頭開始訓練,只要載入其訓練好的模型引數即可,這樣有助於提高訓練的準確率和我們模型的泛化能力。

 model = DPN(num_init_features=64, k_R=96, G=32, k_sec=(3,4,20,3), inc_sec=(16,32,24,128), num_classes=1,decoder=args.decoder)
 http = {'url': 'http://data.lip6.fr/cadene/pretrainedmodels/dpn92_extra-b040e4a9b.pth'}
 pretrained_dict=model_zoo.load_url(http['url'])
 model_dict = model.state_dict()
 pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}#filter out unnecessary keys 
 model_dict.update(pretrained_dict)
 model.load_state_dict(model_dict)
 model = torch.nn.DataParallel(model).cuda()

因為需要刪除預訓練模型中不匹配的的鍵,也就是層的名字。

相關推薦

Pytorch載入部分訓練模型引數

前言自從從深度學習框架caffe轉到Pytorch之後,感覺Pytorch的優點妙不可言,各種設計簡潔,方便研究網路結構修改,容易上手,比TensorFlow的臃腫好多了。對於深度學習的初學者,Pytorch值得推薦。今天主要主要談談Pytorch是如何載入預訓練模型的引數以

pytorch 如何載入部分訓練模型

分享一下我老師大神的人工智慧教程!零基礎,通俗易懂!http://blog.csdn.net/jiangjunshow 也歡迎大家轉載本篇文章。分享知識,造福人民,實現我們中華民族偉大復興!        

載入resNet訓練模型

# Assume input range is [0, 1] class ResNet101FeatureExtractor(nn.Module): def __init__(self, use_input_norm=True, device=torch.device('cpu'))

【tf.keras】tf.keras載入AlexNet訓練模型

目錄 從 PyTorch 中匯出模型引數 第 0 步:配置環境 第 1 步:安裝 MMdnn 第 2 步:得到 PyTorch 儲存完整結構和引數的模型(pth 檔案) 第 3 步:匯出 Py

Pytorch 快速入門(七)載入訓練模型初始化網路引數

在預訓練網路的基礎上,修改部分層得到自己的網路,通常我們需要解決的問題包括: 1. 從預訓練的模型載入引數 2. 對新網路兩部分設定不同的學習率,主要訓練自己新增的層 PyTorch提供的預訓練模型PyTorch定義了幾個常用模型,並且提供了預訓練版本:AlexNet: Al

PyTorch中使用訓練模型初始化網路的一部分引數(增減網路層,修改某層引數等) 固定引數

在預訓練網路的基礎上,修改部分層得到自己的網路,通常我們需要解決的問題包括: 1. 從預訓練的模型載入引數  2. 對新網路兩部分設定不同的學習率,主要訓練自己新增的層  一. 載入引數的方法:  載入引數可以參考apaszke推薦的做法,即刪除與當前mo

pytorch學習筆記之載入訓練模型

原文:https://blog.csdn.net/weixin_41278720/article/details/80759933  pytorch自發布以來,由於其便捷性,贏得了越來越多人的喜愛。 Pytorch有很多方便易用的包,今天要談的是torchvision包,

PyTorch學習系列(十五)——如何載入訓練模型

PyTorch提供的預訓練模型 PyTorch定義了幾個常用模型,並且提供了預訓練版本: AlexNet: AlexNet variant from the “One weird trick” paper. VGG: VGG-11, VGG-13, VGG

PyTorch-網路的建立,訓練模型載入

本文是PyTorch使用過程中的的一些總結,有以下內容: 構建網路模型的方法 網路層的遍歷 各層引數的遍歷 模型的儲存與載入 從預訓練模型為網路引數賦值 主要涉及到以下函式的使用 add_module,ModulesList,Sequential 模型建立 modules(),named_modules

Pytorch使用訓練模型加速訓練的技巧

當屬於預訓練模型屬於下面的情況的時候,可以採用這個加速的技巧: 固定前部分的層,只改變網路後面層的引數。 比如,使用vgg16的預訓練模型,固定特徵提取層,改變後面的全連線層。要注意的是,如果固定的是特徵提取層+一個全連線層,也可以使用這個技巧,只要固定的是前一部分。

pytorch 更改訓練模型網路結構

一個繼承nn.module的model它包含一個叫做children()的函式,這個函式可以用來提取出model每一層的網路結構,在此基礎上進行修改即可,修改方法如下(去除後兩層): resnet_layer = nn.Sequential(*list(model.children())[:-2])

pytorch 訓練模型修改

# coding=UTF-8 import torchvision.models as models import torch import torch.nn as nn import math import torch.utils.model_zoo as model_zoo class C

PyTorch—torchvision.models匯入訓練模型與殘差網路講解

文章目錄 torchvision.models 1. 模組呼叫 2. 原始碼解析 3. ResNet類 4. Bottlenect類 5. BasicB

pytorch fine-tune 訓練模型

之一: torchvision 中包含了很多預訓練好的模型,這樣就使得 fine-tune 非常容易。本文主要介紹如何 fine-tune torchvision 中預訓練好的模型。 安裝 pip install torchvision 如何 fine-tune 以

Tensorflow載入訓練模型和儲存模型

使用tensorflow過程中,訓練結束後我們需要用到模型檔案。有時候,我們可能也需要用到別人訓練好的模型,並在這個基礎上再次訓練。這時候我們需要掌握如何操作這些模型資料。看完本文,相信你一定會有收穫! 1 Tensorflow模型檔案 我們在checkpo

MXNet學習 (1) :載入訓練模型

首先在MXNet的model zoo下載對應的模型描述檔案以及模型引數檔案: vgg16:對應vgg16.json vgg16-0000.params resnet50:對應resnet50.json resnet50-0000.params

機器學習引數設定與訓練模型設定

使用tensorlayer時,出現了大量相關的引數設定,通用的引數設定如下:task = 'dcgan' flags = tf.app.flags flags.DEFINE_string('task','dcgan','this task name') flags.DEFIN

韓國小哥哥用Pytorch實現谷歌最強NLP訓練模型BERT | 程式碼

乾明 編譯整理自 GitHub 量子位 報道 | 公眾號 QbitAI新鮮程式碼,還熱乎著呢。前

小白程式設計用Pytorch匯入訓練模型&&設定不同學習速率

前兩天正好在做這個部分,參考了很多網友的做法,也去pytorch論壇查了一下,現在總結如下。建議還是自己單步除錯一下看看每個引數裡面的值是什麼樣的比較好。1.匯入預訓練的模型,預訓練模型是現有模型的一個或者幾個部分假設我有一個網路包含 pretrained和classify兩