1. 程式人生 > 程式設計 >tensorflow沒有output結點,儲存成pb檔案的例子

tensorflow沒有output結點,儲存成pb檔案的例子

Tensorflow中儲存成pb file 需要 使用函式

graph_util.convert_variables_to_constants(sess,sess.graph_def,

output_node_names=[]) []中需要填寫你需要儲存的結點。如果儲存的結點在神經網路中沒有被顯示定義該怎麼辦?

例如我使用了tf.contrib.slim或者keras,在tf的高層很多情況下都會這樣。

在寫神經網路時,只需要簡單的一層層傳導,一個slim.conv2d層就包含了kernal,bias,activation function,非常的方便,好處是網路結構一目瞭然,壞處是什麼呢?

在嘗試儲存pb的 output node names時,需要將最後的輸出結點儲存下來,與這個結點相關的,從輸入開始,經過層層傳遞的巢狀函式或者操作的相關結點,都會被儲存,但無效的例如 計算準確率,計算loss等,就可以省略了,因為儲存的pb主要是用來做預測的。

在準備檢視所有的結點名稱並選取儲存時,發現scope "local3"裡面僅有相關的weights 和biases,這兩個是單獨存在的,即儲存這兩個引數並沒有任何意義。

那麼這時候有兩種解決辦法:

方法一:

graph_util.convert_variables_to_constants(sess,output_node_names=[var.name[:-2] for var in tf.global_variables()])

那麼這個的意思是所有的variable的都被儲存下來 但函式中要求的是 node name 我們通過 global_variables獲得的是 變數名 並不是 節點名

(例如 output:0 就是變數名,又叫tensor name)

output就是 node name了。

在tensorboard中可以一窺究竟

通過這樣 也可以將 所有的變數全部儲存下來(但是你並不能使用,是因為你的output並沒有名字,所以你不可以通過常用的sess.graph.get_tensor_by_name來使用)

方法二:

那就是直接改寫神經網路了....當然了還是比較簡單的,只要改寫最後一個,改寫成output即可,tensorflow中無論是 變數、操作op、函式、都可以命名,那麼這個地方是一個簡單的全連線,僅需要將weights*net(上一層的輸出) +bias 即可,我們只要將bias相加的結果命名為 ouput即可:

with tf.name_scope('local3'):
 
  local3_weights = tf.Variable(tf.truncated_normal([4096,self.output_size],stddev=0.1))
 
  local3_bias = tf.Variable(tf.constant(0.1,shape=[self.output_size]))
 
result = tf.add(tf.matmul(net,local3_weights),local3_bias,name="output")

這樣將上述的convert_variables_to_constants中的output_node_names只需要填寫一個['output']即可,因為這一個output結點,需要從input開始,將所有的神經網路前向傳播的操作和引數全部儲存下來,因此儲存的結點數量 和 方法一儲存的結點數量是一樣的(console顯示都是 convert 24)。

完整的pb儲存為:(我是將ckpt讀入進來,然後存成pb的)

from tensorflow.python.platform import gfile
 
 
 
load_ckpt():
 
  path = './data/output/loss1.0/'
 
  print("read from ckpt")
 
  ckpt = tf.train.get_checkpoint_state(path)
 
  saver = tf.train.Saver()
 
  saver.restore(sess,ckpt.model_checkpoint_path)
 
 
 
def write2pb_file():
 
  constant_graph = graph_util.convert_variables_to_constants(sess,output_node_names=["output"])
 
  with tf.gfile.GFile(path+'loss1.0.pb',mode='wb') as f:
 
  f.write(constant_graph.SerializeToString())
 
  print("Model is saved as " + path+'loss1.0.pb')
 
 
 
def main():
 
  load_ckpt()
 
  write2pb_file()
 

如果是簡單的直接儲存,那就更簡單了。

pb檔案的read,很多人會將一個net寫成一個類,在引入的時候會將新建這個類,然後讀入ckpt檔案,這完全沒有問題,但是在讀取pb時,就會發生問題,因為pb中已經包含了圖與引數,引入時會建立一個預設的圖,但是net類中自己也會建立一個圖,那麼這時候你執行程式,引數其實並沒有使用.pb的檔案。

所以我們不能建立net類,然後直接讀入.pb檔案,對.pb檔案,通過如下程式碼,獲取.pb的graph中的輸入和輸出。

self.output = self.sess.graph.get_tensor_by_name("output:0")
 
self.input = self.sess.graph.get_tensor_by_name("images:0")

注意此時要加:0 因為你獲取的不再是結點了,而是一個真實的變數,我的理解是,結點相當於一個類,:0是物件,預設初始化值就是物件的初始化。

然後就可以通過self.sess.run(self.output(feed_dict={self.input: your_input})))執行你的網路了!

以上這篇tensorflow沒有output結點,儲存成pb檔案的例子就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。