Pytorch模型的儲存與載入
前言
在使用Pytorch訓練模型的時候,經常會有在GPU上儲存模型然後再CPU上執行的需求,在實驗的過程中發現在多GPU上訓練的Pytorch模型是不能在CPU上直接執行的,幾次遇到了這種問題,這裡研究和記錄一下。
模型的儲存與載入
例如我們建立了一個模型:
model = MyVggNet()
如果使用多GPU訓練,我們需要使用這行程式碼:
model = nn.DataParallel(model).cuda()
執行這個程式碼之後,model就不在是我們原來的模型,而是相當於在我們原來的模型外面加了一層支援CPU執行的外殼,這時候真正的模型物件為:real_model = model.module
Pytorch有多種儲存模型的方式,使用哪種進行儲存,就要使用對應的載入方式。儲存的時候模型的字尾名是無所謂的。
Pytorch官方的載入和儲存模型的方式有兩種:
1. 儲存和載入整個模型
torch.save(model_object, 'model.pkl')
model = torch.load('model.pkl')
- 僅儲存和載入模型引數(推薦使用)
torch.save(model_object.state_dict(), 'params.pkl')
model_object.load_state_dict(torch.load('params.pkl'))
模型儲存與載入對應方式
1. 第一種方式
儲存使用:
real_model = model.module
torch.save(real_model.state_dict(),os.path.join(args.save_path,"cos_mnist_"+str(epoch+1)+"_weight.pth"))
cpu上載入使用:
args.weight=checkpoint/cos_mnist_10_weight.pth
map_location = lambda storage, loc: storage
model.load_state_dict(torch.load(args.weight,map_location=map_location))
2. 第二種方式
儲存使用:
real_model = model.module
save_model(real_model, os.path.join(args.save_path,"cos_mnist_"+str(epoch+1)+"_weight_cpu.pth"))
# 自定義的函式
def save_model(model,filename):
state = model.state_dict()
for key in state: state[key] = state[key].clone().cpu()
torch.save(state, filename)
cpu上載入使用:
args.weight=checkpoint/cos_mnist_10_weight_cpu.pth
model.load_state_dict(torch.load(args.weight))
3. 第三種方式
儲存使用:
real_model = model.module
torch.save(real_model, os.path.join(args.save_path,"cos_mnist_"+str(epoch+1)+"_whole.pth"))
cpu上載入使用:
args.weight=checkpoint/cos_mnist_10_whole.pth
map_location = lambda storage, loc: storage
model = torch.load(args.weight,map_location=map_location)
參考文獻
相關推薦
[PyTorch 學習筆記] 7.1 模型儲存與載入
> 本章程式碼: > > - [https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/model_save.py](https://github.com/zhangxiann/PyTorch_Practice/b
TensorFlow實現Softmax迴歸(模型儲存與載入)
1 # -*- coding: utf-8 -*- 2 """ 3 Created on Thu Oct 18 18:02:26 2018 4 5 @author: zhen 6 """ 7 8 from tensorflow.examples.tutorials.mnist imp
python opencv3.x中支援向量機(svm)模型儲存與載入問題
親自驗證,可以解決svm的模型載入問題: import numpy as np from sklearn import datasets &nb
keras中訓練好的模型儲存與載入
keras中的採用Sequential模式建立DNN並持久化保持、重新載入 def DNN_base_v1(X_train, y_train): model = models.Sequential() model.add(layers.Dense(96,
Keras 深度學習程式碼筆記——模型儲存與載入
你可以使用model.save(filepath)將Keras模型和權重儲存在一個HDF5檔案中,該檔案將包含: 模型的結構,以便重構該模型 模型的權重 訓練配置(損失函式,優化器等) 優化器的狀態,以便於從上次訓練中斷的地方開始 使用keras.mod
tensorflow 模型儲存與載入
在訓練一個神經網路模型後,你會儲存這個模型未來使用或部署到產品中。所以,什麼是TF模型?TF模型基本包含網路設計或圖,與訓練得到的網路引數和變數。因此,TF模型具有兩個主要檔案: a)meta圖 這是一個擬定的快取,包含了這個TF圖完整資訊;如所有變數等
Keras中的模型儲存與載入
from keras.models import Sequential from keras.layers import Dense from keras.models import load_model model = Sequential() model.add(Dens
Pytorch模型的儲存與載入
前言 在使用Pytorch訓練模型的時候,經常會有在GPU上儲存模型然後再CPU上執行的需求,在實驗的過程中發現在多GPU上訓練的Pytorch模型是不能在CPU上直接執行的,幾次遇到了這種問題,這裡研究和記錄一下。 模型的儲存與載入 例如我們建立了一
【小白學PyTorch】19 TF2模型的儲存與載入
【新聞】:機器學習煉丹術的粉絲的人工智慧交流群已經建立,目前有目標檢測、醫學影象、時間序列等多個目標為技術學習的分群和水群嘮嗑的總群,歡迎大家加煉丹兄為好友,加入煉丹協會。微信:cyx645016617. 參考目錄: [TOC] 本文主要講述TF2.0的模型檔案的儲存和載入的多種方法。主要分成兩型別:模型
tensorflow模型的儲存與載入
1.儲存:(儲存的變數都是停放,tf.Variable()中的變數,變數一定要有名字) saver = tf.train.Saver() saver.run(sess,"./model4/line_model.ckpt") 2.檢視儲存的變數資訊:(將儲存的資訊打印
基於pytorch的 儲存和載入模型引數
當我們花費大量的精力訓練完網路,下次預測資料時不想再(有時也不必再)訓練一次時,這時候torch.save(),torch.load()就要登場了。 儲存和載入模型引數有兩種方式: 方式一: torch.save(net.state_dict(),path): 功能
Keras儲存與載入模型(JSON+HDF5)
在Keras中,有時候需要對模型進行序列化與反序列化。進行模型序列化時,會將模型結果與模型權重儲存在不同的檔案中,模型權重通常儲存在HDF5檔案中,模型的結構可以儲存在JSON或者YAML檔案中。後二者方法大同小異,這裡以JSON為例說明一下Keras模型的儲存與載入。 from sklearn
python sklearn svm模型的儲存與載入呼叫
對於機器學習的一些模型,跑完之後,如果下一次測試又需要重新跑一遍模型是一件很繁瑣的事,這時候我們就需要儲存模型,再載入呼叫。 樓主發現有這些儲存模型的方法,網上有很多錯誤的例子,所以給大家在整理一下。(python3) 1.利用pickle import pickle
[TensorFlow深度學習入門]實戰八·簡便方法實現TensorFlow模型引數儲存與載入(pb方式)
[TensorFlow深度學習入門]實戰八·簡便方法實現TensorFlow模型引數儲存與載入(pb方式) 在上篇博文中,我們探索了TensorFlow模型引數儲存與載入實現方法採用的是儲存ckpt的方式。這篇博文我們會使用儲存為pd格式檔案來實現。 首先,我會在上篇博文基礎上,實現由c
[TensorFlow深度學習入門]實戰七·簡便方法實現TensorFlow模型引數儲存與載入(ckpt方式)
[TensorFlow深度學習入門]實戰七·簡便方法實現TensorFlow模型引數儲存與載入(ckpt方式) TensorFlow模型訓練的好網路引數如果想重複高效利用,模型引數儲存與載入是必須掌握的模組。本文提供一種簡單容易理解的方式來實現上述功能。參考部落格地址 備註: 本文采用的
模型儲存,載入與呼叫
模型儲存 BP: model.save(save_dir) SVM: from sklearn.externals import joblib joblib.dump(clf, save_dir) 模型載入 BP: from keras.models im
pytorch資料載入、模型儲存及載入
主要涉及的Pytorch官方示例下圖紅框部分的一些翻譯及備註。 1、資料載入及處理 該部分主要是用於進行資料集載入及資料預處理說明,使用的資料集為:人臉+標註座標。demo程式需要pandas(讀取CSV檔案)及scikit-image(影象變換)這兩個包。 1.1、jup
Keras 儲存與載入網路模型
遇到問題: keras使用預訓練模型做訓練時遇到的如下程式碼: from keras.utils.data_utils import get_file WEIGHTS_PATH = 'https://github.com/fchollet/deep-lea
TensorFlow下網路模型的儲存與載入
#!/usr/bin/env python# 匯入mnist資料庫from tensorflow.examples.tutorials.mnist import input_datamnist = input_data.read_data_sets("MNIST_data/", one_hot=True)i
模型的儲存與載入
tensorflow: 有兩種方式儲存和載入模型。 ①生成checkpoint file,副檔名為.ckpt,通過在tf.train.Saver物件上呼叫Saver.save()生成。包含權重和變數,但不包括圖的結構。如果需要在另一個程式中使用,需要重新建立