
TensorFlow Model Conversion: ckpt to pb, h5 to pb

This post summarizes how to convert model files trained with TensorFlow or Keras into the pb format, so you do not have to hunt for conversion tools yourself.

1. Converting a TensorFlow ckpt model to a pb model

Models trained in TensorFlow are usually saved in ckpt format. One checkpoint corresponds to three files: xxx.ckpt.data, xxx.ckpt.meta, and xxx.ckpt.index.
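
For reference, these three files are what tf.train.Saver writes when you save a checkpoint. Below is a minimal TensorFlow 1.x sketch; the variable w is just a toy example and xxx.ckpt is a placeholder path:

import tensorflow as tf

# A toy variable so the graph has something to save.
w = tf.Variable(tf.zeros([1]), name="w")
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Writes xxx.ckpt.data-00000-of-00001, xxx.ckpt.index and xxx.ckpt.meta
    # (plus a small "checkpoint" state file).
    saver.save(sess, "xxx.ckpt")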

In production, however, C++ programs can generally only load pb models, so the three ckpt files have to be merged into one: a single model corresponds to a single pb file (several models can even be merged into one pb, but that is not covered here).

Without further ado, here is the code:

import tensorflow as tf
from tensorflow.python.framework import graph_util


def freeze_graph(input_checkpoint, output_graph):
    '''
    :param input_checkpoint: xxx.ckpt (do not append a suffix such as xxx.ckpt.data; stop at .ckpt!)
    :param output_graph: path where the pb model will be saved
    :return:
    '''
    # checkpoint = tf.train.get_checkpoint_state(model_folder)  # check whether the ckpt files in the directory are usable
    # input_checkpoint = checkpoint.model_checkpoint_path       # get the ckpt file path

    # Name of the output node; it must be a node that actually exists in the
    # original model. Adjust it to match your own network.
    output_node_names = "softmax"
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
    graph = tf.get_default_graph()          # get the default graph
    input_graph_def = graph.as_graph_def()  # serialized GraphDef of the current graph

    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)  # restore the graph and its weights
        output_graph_def = graph_util.convert_variables_to_constants(  # freeze: turn variables into constants
            sess=sess,
            input_graph_def=input_graph_def,                 # equivalent to sess.graph_def
            output_node_names=output_node_names.split(","))  # separate multiple output nodes with commas
        with tf.gfile.GFile(output_graph, "wb") as f:        # save the model
            f.write(output_graph_def.SerializeToString())    # serialize and write to disk
        print("%d ops in the final graph." % len(output_graph_def.node))  # number of ops in the frozen graph

Usage:

input_checkpoint = 'xxx.ckpt'
out_graph = 'froze_xxx.pb'
freeze_graph(input_checkpoint, out_graph) 
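
A common stumbling block is not knowing what to put in output_node_names, since it has to be a node that exists in the original model. One rough way to find it (a sketch, reusing the xxx.ckpt placeholder from above) is to import the meta graph and print every operation name:

import tensorflow as tf

tf.reset_default_graph()
saver = tf.train.import_meta_graph('xxx.ckpt.meta', clear_devices=True)
graph = tf.get_default_graph()
for op in graph.get_operations():
    print(op.name)  # look for your output node (e.g. a Softmax op) in this list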

2. Converting a Keras h5 model to a pb model

Keras and TensorFlow are now much more tightly integrated: you can use Keras from within TensorFlow via tf.contrib.keras, so the two are coupled with each other rather than Keras simply being a high-level wrapper on top of TensorFlow as before.

Because Keras wraps many ops behind a very simple interface, a lot of people build their models with Keras these days. The question then becomes: if you want to use an hdf5 (h5) model file produced by Keras in production, you likewise need to convert it to pb format. How? Let's roll!

import os
import os.path as osp
import sys

import tensorflow as tf
from keras import backend as K
from keras.models import load_model


def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """
    Freezes the state of a session into a pruned computation graph.

    Creates a new computation graph where variable nodes are replaced by
    constants taking their current value in the session. The new graph will be
    pruned so that subgraphs that are not necessary to compute the requested
    outputs are removed.
    @param session The TensorFlow session to be frozen.
    @param keep_var_names A list of variable names that should not be frozen,
                          or None to freeze all the variables in the graph.
    @param output_names Names of the relevant graph outputs.
    @param clear_devices Remove the device directives from the graph for better portability.
    @return The frozen graph definition.
    """
    from tensorflow.python.framework.graph_util import convert_variables_to_constants
    graph = session.graph
    with graph.as_default():
        freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
        output_names = output_names or []
        output_names += [v.op.name for v in tf.global_variables()]
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = convert_variables_to_constants(session, input_graph_def,
                                                      output_names, freeze_var_names)
        return frozen_graph


input_fld = sys.path[0]
weight_file = 'vgg16_without_dropout.h5'
output_graph_name = 'vgg16_without_dropout.pb'

output_fld = input_fld + '/tensorflow_model/'
if not os.path.isdir(output_fld):
    os.mkdir(output_fld)
weight_file_path = osp.join(input_fld, weight_file)

K.set_learning_phase(0)
net_model = load_model(weight_file_path)


print('input is :', net_model.input.name)
print('output is:', net_model.output.name)

sess = K.get_session()

frozen_graph = freeze_session(sess, output_names=[net_model.output.op.name])

from tensorflow.python.framework import graph_io

graph_io.write_graph(frozen_graph, output_fld, output_graph_name, as_text=False)

print('saved the constant graph (ready for inference) at: ', osp.join(output_fld, output_graph_name))
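
To sanity-check the exported pb, you can load it back and run a forward pass on dummy data. Below is a rough TensorFlow 1.x sketch that reuses the tensor names already held by net_model; it assumes the model's input shape is fully defined (as it is for VGG16):

import numpy as np

with tf.gfile.GFile(osp.join(output_fld, output_graph_name), 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as g:
    tf.import_graph_def(graph_def, name='')
    x = g.get_tensor_by_name(net_model.input.name)   # same tensor names as in the Keras graph
    y = g.get_tensor_by_name(net_model.output.name)
    with tf.Session(graph=g) as sess:
        dummy = np.zeros([1] + list(net_model.input_shape[1:]), dtype=np.float32)  # one all-zero sample
        print(sess.run(y, feed_dict={x: dummy}).shape)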
