1. 程式人生 > >tensorflow 儲存模型和取出中間權重

tensorflow 儲存模型和取出中間權重

下面程式碼的功能是先訓練一個簡單的模型,然後儲存模型,同時儲存到一個pb檔案當中,後續可以從pd檔案裡讀取權重值。

import tensorflow as tf
import numpy as np
import os
import h5py
import pickle
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
#設定使用指定GPU
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
#下面這段程式碼是在訓練好之後將所有的權重名字和權重值羅列出來,訓練的時候需要註釋掉
reader = tf.train.NewCheckpointReader('./model.ckpt-100') variables = reader.get_variable_to_shape_map() for ele in variables: print(ele) print(reader.get_tensor(ele)) x = tf.placeholder(tf.float32, shape=[None, 1]) y = 4 * x + 4 w = tf.Variable(tf.random_normal([1], -1, 1)) b = tf.Variable(tf.zeros([1
])) y_predict = w * x + b loss = tf.reduce_mean(tf.square(y - y_predict)) optimizer = tf.train.GradientDescentOptimizer(0.5) train = optimizer.minimize(loss) isTrain = False#設成True去訓練模型 train_steps = 100 checkpoint_steps = 50 checkpoint_dir = '' saver = tf.train.Saver() # defaults to saving all variables - in this case w and b
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1)) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) if isTrain: for i in xrange(train_steps): sess.run(train, feed_dict={x: x_data}) if (i + 1) % checkpoint_steps == 0: saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1) else: ckpt = tf.train.get_checkpoint_state(checkpoint_dir) if ckpt and ckpt.model_checkpoint_path: saver.restore(sess, ckpt.model_checkpoint_path) else: pass print(sess.run(w)) print(sess.run(b)) graph_def = tf.get_default_graph().as_graph_def() #通過修改下面的函式,個人覺得理論上能夠實現修改權重,但是很複雜,如果哪位有好辦法,歡迎指教 output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['Variable']) with tf.gfile.FastGFile('./test.pb', 'wb') as f: f.write(output_graph_def.SerializeToString()) with tf.Session() as sess: #對應最後一部分的寫,這裡能夠將對應的變數取出來 with gfile.FastGFile('./test.pb', 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) res = tf.import_graph_def(graph_def, return_elements=['Variable:0']) print(sess.run(res)) print(sess.run(graph_def))

相關推薦

tensorflow 儲存模型取出中間權重

下面程式碼的功能是先訓練一個簡單的模型,然後儲存模型,同時儲存到一個pb檔案當中,後續可以從pd檔案裡讀取權重值。 import tensorflow as tf import numpy as np import os import h5py impor

tensorflow儲存模型恢復模型

儲存模型 w1 = tf.placeholder("float", name="w1") w2 = tf.placeholder("float", name="w2") b1= tf.Variable(2.0,name="bias") feed_dict ={w1:4,w2:8} w3 =

tensorflow儲存訓練的權重為.pb,然後讀取.pb並使用

(1)tensorflow儲存圖和訓練好的權重 from __future__ import absolute_import, unicode_literals import input_data import tensorflow as tf import shutil

tensorflow儲存模型、載入模型提取模型引數特徵圖

1.tf.train.latest_checkpoint('./model_data/')這一句最終返回的是一個字串,比如'./model_data/model-99991'這個方法本身還會做相應的檢查,比如checkpoint中最新的模型model_checkpoint_p

tensorflow 儲存模型

#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Created on Thu Oct 25 15:29:59 2018 @author: lg """ import tensorflow as tf import numpy as

Tensorflow儲存模型,恢復模型

tensorflow從已經訓練好的模型中,恢復(指定)權重(構建新變數、網路)並繼續訓練(finetuning) https://blog.csdn.net/ying86615791/article/details/76215363   Tensorflow儲存模型,恢復模型,

TensorFlow 儲存模型為 PB 檔案

通常我們使用 TensorFlow時儲存模型都使用 ckpt 格式的模型檔案,使用類似的語句來儲存模型 tf.train.Saver().save(sess,ckpt_file_path,max_to_keep=4,keep_checkpoint_ever

深度學習【13】tensorflow儲存graph引數為pb檔案

from tensorflow.python.framework.graph_util import convert_variables_to_constants graph = convert_v

Java 儲存模型共享物件詳解

Java 儲存模型和共享物件詳解 很多程式設計師對一個共享變數初始化要注意可見性和安全釋出(安全地構建一個物件,並其他執行緒能正確訪問)等問題不是很理解,認為Java是一個遮蔽記憶體細節的平臺,連物件回收都不需要關心,因此談到可見性和安全釋出大多不知所云。其實關鍵在於對Java儲存模型,可見性和

keras 儲存模型載入模型

import numpy as np np.random.seed(1337) # for reproducibility from keras.models import Sequential from keras.layers import Dense from k

--Set * 無序(儲存順序取出順序不一致),唯一

package cn.itcast_01; import java.util.HashSet; import java.util.Set; /*  * Collection  * |--List  * 有序(儲存順序和取出順序一致),可重複  * |--Set  * 無序(

keras儲存模型載入模型

1、儲存模型和載入模型的方法 用實驗室的伺服器跑神經網路的時候伺服器老是斷開連線,這對我的訓練和測試來時是一件比較崩潰的事,因為這意味著我要重新訓練一次,要浪費又一次的時間,所以我在網上百度了儲存模型和載入模型的辦法,大部分的方法如下: 儲存模型 model.s

TensorFlow 模型權重儲存及預測

TensorFlow模型和權重的儲存 因為大肥狼在使用儲存的模型和權重進行預測時遇到了一些問題,所以本文將介紹如何在TensorFlow中儲存模型和權重,並如何使用儲存的模型和權重來進行預測。 1.程式碼 我們的程式碼主要是吳恩達作業第二單元第三週-----tensorflow入

TensorFlow儲存提取模型

一、模型的儲存 將訓練好的模型引數儲存起來,以便以後進行驗證或測試,這是我們經常要做的事情。tf裡面提供模型儲存的是tf.train.Saver()模組。 1、模型儲存,先要建立一個Saver物件:如 saver=tf.train.Saver() __init__( var_

TensorFlow儲存載入訓練模型

儲存:使用saver.save()方法儲存 載入:使用saver.restore()方法載入 下面是個完整例子: 儲存: import tensorflow as tf W = tf.Variable([[1, 1, 1], [2, 2, 2]], dtype=tf.float

tensorflow儲存載入模型

× TF 儲存和載入模型 <!-- 作者區域 --> <div class="author"> <a class="avatar" href="/u/ff5c

Tensorflow載入預訓練模型儲存模型

使用tensorflow過程中,訓練結束後我們需要用到模型檔案。有時候,我們可能也需要用到別人訓練好的模型,並在這個基礎上再次訓練。這時候我們需要掌握如何操作這些模型資料。看完本文,相信你一定會有收穫! 1 Tensorflow模型檔案 我們在checkpo

深度學習tensorflow實戰筆記(3)VGG-16訓練自己的資料並測試儲存模型

    前面的部落格介紹瞭如何把影象資料轉換成tfrecords格式並讀取,本篇部落格介紹如何用自己的tfrecords格式的資料訓練CNN模型,採用的模型是VGG-16。現有的教程都是在mnist或者cifar-10資料集上做的訓練,如何用自己的資料集進行訓練相關的資料比較

tensorflow 檢查點模型,儲存與恢復使用,官方教程(一)

檢查點:這種格式依賴於建立模型的程式碼。SavedModel:這種格式與建立模型的程式碼無關。示例程式碼本文件依賴於 TensorFlow 使用入門中詳細介紹的同一個鳶尾花分類示例。要下載和訪問該示例,請執行下列兩個命令:git clone https://github.co

Tensorflow的變數模型儲存以及模型應用

Table of Contents 一、模型部分(成功) 1.儲存的模型 2.載入模型並用於預測 1.載入圖結構和引數 2.獲取圖 3.獲取tensor 4.新的inpu