1. 程式人生 > 程式設計 >Tensorflow 使用pb檔案儲存(恢復)模型計算圖和引數例項詳解

Tensorflow 使用pb檔案儲存(恢復)模型計算圖和引數例項詳解

一、儲存:

graph_util.convert_variables_to_constants 可以把當前session的計算圖序列化成一個位元組流(二進位制),這個函式包含三個引數:引數1:當前活動的session,它含有各變數

引數2:GraphDef 物件,它描述了計算網路

引數3:Graph圖中需要輸出的節點的名稱的列表

返回值:精簡版的GraphDef 物件,包含了原始輸入GraphDef和session的網路和變數資訊,它的成員函式SerializeToString()可以把這些資訊序列化為位元組流,然後寫入檔案裡:

constant_graph = graph_util.convert_variables_to_constants( sess,sess.graph_def,['sum_operation'] )
with open( pbName,mode='wb') as f:
f.write(constant_graph.SerializeToString())

需要指出的是,如果原始張量(包含在引數1和引數2中的組成部分)不參與引數3指定的輸出節點列表所指定的張量計算的話,這些張量將不會存在返回的GraphDef物件裡,也不會被序列化寫入pb檔案。

二、恢復:

恢復時,建立一個GraphDef,然後從上述的檔案里加載進來,接著輸入到當前的session:

    graph0 = tf.GraphDef()
    with open( pbName,mode='rb') as f:
      graph0.ParseFromString( f.read() )
      tf.import_graph_def( graph0,name = '' )

三、程式碼:

 
import tensorflow as tf 
from tensorflow.python.framework import graph_util
 
pbName = 'graphA.pb'
def graphCreate() :
  with tf.Session() as sess :
    var1 = tf.placeholder ( tf.int32,name='var1' ) 
    var2 = tf.Variable( 20,name='var2' )#實參name='var2'指定了操作名,該操作返回的張量名是在
                       #'var2'後面:0,即var2:0 是返回的張量名,也就是說變數
                       # var2的名稱是'var2:0'
    var3 = tf.Variable( 30,name='var3' )
    var4 = tf.Variable( 40,name='var4' )
    var4op = tf.assign( var4,1000,name = 'var4op1' )
    sum = tf.Variable( 4,name='sum' )
    sum = tf.add ( var1,var2,name = 'var1_var2' ) 
    sum = tf.add( sum,var3,name='sum_var3' )
    sumOps = tf.add( sum,var4,name='sum_operation' )
    oper = tf.get_default_graph().get_operations()
    with open( 'operation.csv','wt' ) as f:
      s = 'name,type,output\n'
      f.write( s ) 
      for o in oper:
        s = o.name
        s += ','+ o.type 
        inp = o.inputs
        oup = o.outputs
        for iip in inp :
          s #s += ','+ str(iip)
        for iop in oup :
          s += ',' + str(iop)
        s += '\n'
        f.write( s ) 
         
      for var in tf.global_variables():
        print('variable=> ',var.name) #張量是tf.Variable/tf.Add之類操作的結果,
                        #張量的名字使用操作名加:0來表示
    init = tf.global_variables_initializer()
    sess.run( init )
    sess.run( var4op )
    print('sum_operation result is Tensor ',sess.run( sumOps,feed_dict={var1:1}) )
 
    constant_graph = graph_util.convert_variables_to_constants( sess,['sum_operation'] )
    with open( pbName,mode='wb') as f:
      f.write(constant_graph.SerializeToString())
 
def graphGet() :
  print("start get:" )
  with tf.Graph().as_default():
    graph0 = tf.GraphDef()
    with open( pbName,name = '' )
    with tf.Session() as sess :
      init = tf.global_variables_initializer()
      sess.run(init)
      v1 = sess.graph.get_tensor_by_name('var1:0' )
      v2 = sess.graph.get_tensor_by_name('var2:0' )
      v3 = sess.graph.get_tensor_by_name('var3:0' )
      v4 = sess.graph.get_tensor_by_name('var4:0' )
      
      sumTensor = sess.graph.get_tensor_by_name("sum_operation:0")
      print('sumTensor is : ',sumTensor )
      print( sess.run( sumTensor,feed_dict={v1:1} ) ) 
  
graphCreate()
graphGet()
  

四、儲存pb函式程式碼裡的操作名稱/型別/返回的張量:

operation name operation type output
var1 Placeholder Tensor("var1:0" dtype=int32)
var2/initial_value Const Tensor("var2/initial_value:0" shape=() dtype=int32)
var2 VariableV2 Tensor("var2:0" shape=() dtype=int32_ref)
var2/Assign Assign Tensor("var2/Assign:0" shape=() dtype=int32_ref)
var2/read Identity Tensor("var2/read:0" shape=() dtype=int32)
var3/initial_value Const Tensor("var3/initial_value:0" shape=() dtype=int32)
var3 VariableV2 Tensor("var3:0" shape=() dtype=int32_ref)
var3/Assign Assign Tensor("var3/Assign:0" shape=() dtype=int32_ref)
var3/read Identity Tensor("var3/read:0" shape=() dtype=int32)
var4/initial_value Const Tensor("var4/initial_value:0" shape=() dtype=int32)
var4 VariableV2 Tensor("var4:0" shape=() dtype=int32_ref)
var4/Assign Assign Tensor("var4/Assign:0" shape=() dtype=int32_ref)
var4/read Identity Tensor("var4/read:0" shape=() dtype=int32)
var4op1/value Const Tensor("var4op1/value:0" shape=() dtype=int32)
var4op1 Assign Tensor("var4op1:0" shape=() dtype=int32_ref)
sum/initial_value Const Tensor("sum/initial_value:0" shape=() dtype=int32)
sum VariableV2 Tensor("sum:0" shape=() dtype=int32_ref)
sum/Assign Assign Tensor("sum/Assign:0" shape=() dtype=int32_ref)
sum/read Identity Tensor("sum/read:0" shape=() dtype=int32)
var1_var2 Add Tensor("var1_var2:0" dtype=int32)
sum_var3 Add Tensor("sum_var3:0" dtype=int32)
sum_operation Add Tensor("sum_operation:0" dtype=int32)

以上這篇Tensorflow 使用pb檔案儲存(恢復)模型計算圖和引數例項詳解就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。