TensorFlow: printing node tensor_name values from .ckpt and .pb models, and converting a ckpt model to a .pb model

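Before converting a model, you first need to know which node is the output node. Without the source code, it is hard to work out the node information, so the first step is to print the names stored in the model files.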

Getting the node names of a ckpt model

import os
from tensorflow.python import pywrap_tensorflow

checkpoint_path = os.path.join('./ade20k', "model.ckpt-27150")
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
    print("tensor_name: ", key)
    # print(reader.get_tensor(key))  # the corresponding value
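
A long list of names can be hard to read on its own. As a minimal sketch, the variable-to-shape map can be grouped by top-level name scope and the parameter counts summed. The helper names `param_count` and `summarize_by_scope` and the dictionary `example_map` are hypothetical; `example_map` only stands in for the dictionary returned by `reader.get_variable_to_shape_map()`, which maps each variable name to its shape.

```python
from collections import defaultdict
from functools import reduce
from operator import mul


def param_count(shape):
    """Number of elements in a tensor of the given shape (1 for a scalar)."""
    return reduce(mul, shape, 1)


def summarize_by_scope(var_to_shape_map):
    """Sum parameter counts per top-level name scope."""
    totals = defaultdict(int)
    for name, shape in var_to_shape_map.items():
        # "conv1/weights" -> "conv1"; a name without "/" is its own scope
        scope = name.split("/", 1)[0]
        totals[scope] += param_count(shape)
    return dict(totals)


# Hypothetical stand-in for reader.get_variable_to_shape_map()
example_map = {
    "conv1/weights": [3, 3, 3, 64],
    "conv1/biases": [64],
    "fc/weights": [1024, 10],
    "global_step": [],
}

for scope, total in sorted(summarize_by_scope(example_map).items()):
    print(scope, total)
```

With a real checkpoint, `summarize_by_scope(var_to_shape_map)` could be called directly on the map built above. Note that a checkpoint stores variables only, so this summary says nothing about the operations in the graph.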

Getting the node names of a pb model

import tensorflow as tf
import os

model_dir = './'
model_name = 'model.pb'

def create_graph():
    with tf.gfile.FastGFile(os.path.join(
            model_dir, model_name), 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')

create_graph()
tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
for tensor_name in tensor_name_list:
    print(tensor_name, '\n')
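
When the source code is unavailable, one heuristic for narrowing down the output node is to look for nodes that no other node consumes. The sketch below works on plain `(name, inputs)` pairs so that it runs without TensorFlow. The function names and the toy graph are hypothetical. It assumes the usual GraphDef input conventions: `name:N` selects output index N, and `^name` marks a control dependency.

```python
def normalize_input(input_name):
    """Map a NodeDef input string to the name of the node that produces it."""
    name = input_name.lstrip("^")          # "^init" -> "init"
    head, sep, tail = name.rpartition(":")
    if sep and tail.isdigit():             # "split:1" -> "split"
        return head
    return name


def candidate_output_nodes(nodes):
    """Return the names of nodes that never appear as another node's input.

    `nodes` is a list of (name, inputs) pairs in graph order.
    """
    consumed = set()
    for _, inputs in nodes:
        for inp in inputs:
            consumed.add(normalize_input(inp))
    return [name for name, _ in nodes if name not in consumed]


# Hypothetical toy graph, for illustration only
toy_nodes = [
    ("input", []),
    ("conv1/weights", []),
    ("conv1/Conv2D", ["input", "conv1/weights"]),
    ("split", ["conv1/Conv2D"]),
    ("logits", ["split:1"]),
    ("init", ["^conv1/weights"]),
]

print(candidate_output_nodes(toy_nodes))
```

For a real graph, the pairs could be built as `[(n.name, list(n.input)) for n in graph_def.node]`. The result is only a list of candidates: housekeeping nodes such as initializers are also unconsumed (like `init` in the toy graph), so the list still needs to be checked by hand.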

Converting a ckpt model to a pb model

from tensorflow.python.tools import inspect_checkpoint as chkp
import tensorflow as tf

saver = tf.train.import_meta_graph("./ade20k/model.ckpt-27150.meta", clear_devices=True)

# [Important!] Fill in the output node names here
output_nodes = ["xxx"]

with tf.Session(graph=tf.get_default_graph()) as sess:
    input_graph_def = sess.graph.as_graph_def()
    # Load the trained variable values into the imported graph
    saver.restore(sess, "./ade20k/model.ckpt-27150")
    # Replace variables with constants, keeping only what the output nodes need
    output_graph_def = tf.graph_util.convert_variables_to_constants(sess, input_graph_def, output_nodes)
    with open("frozen_model.pb", "wb") as f:
        f.write(output_graph_def.SerializeToString())