1. 程式人生 > >model.clpt轉換為tensorflow serving需要的saved_model.pb

model.clpt轉換為tensorflow serving需要的saved_model.pb

    在向tensorflow serving部署模型的時候需要pb格式的模型檔案,但是之前訓練用的是object detection api,訓練生成的是三個ckpt檔案,然後網上和官方轉換都用的是freeze_graph來讀取ckpt,將圖和引數凍結在一個frozen.pb檔案中,這個是不能直接部署到tensorflow serving上的,所以需要用saved_model將ckpt轉換到saved_model.pb,程式碼如下:

import tensorflow as tf
import os.path
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants

MODEL_DIR = "/home/lyf/models/research/object_detection/faster-rcnn-resnet/model/save"
MODEL_NAME = "saved_model.pb"

'''
---saved_model儲存三個步驟---
	builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir) 
	builder.add_meta_graph_and_variables(sess, ['tag_string']) 
	builder.save()
---載入---
	meta_graph_def = tf.saved_model.loader.load(sess, ['tag_string'], saved_model_dir)
'''

def saved_model_graph(model_folders):

	input_checkpoint = model_folders + '/model.ckpt'#獲得ckpt檔案路徑
	output_graph = os.path.join(MODEL_DIR,MODEL_NAME)
	saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
	graph = tf.get_default_graph()#獲得預設圖
	input_graph_def = graph.as_graph_def()  #返回一個序列化的圖代表當前的圖
	builder = tf.saved_model.builder.SavedModelBuilder(MODEL_DIR)
	sigs = {}
	with tf.Session() as sess:
		saver.restore(sess, input_checkpoint)#恢復圖並得到資料
		builder.add_meta_graph_and_variables(sess,\
						     [tag_constants.SERVING],\
 						     signature_def_map=sigs)
		#第一個引數傳入當前的session,包含了graph的結構與所有變數。
		#第二個引數是給當前需要儲存的meta_graph一個標籤,標籤名可以自定義,在之後載入模型的時候,需要根據這個標籤名去查詢對應的MetaGraphDef
		#標籤也可以選用系統定義好的引數,如tf.saved_model.tag_constants.SERVING與tf.saved_model.tag_constants.TRAINING。
		#SignatureDef定義了一些協議,對所需的資訊進行封裝.                                        	
		
	builder.save()

if __name__ == '__main__':
	parser = argparse.ArugumentParser()
	parser.add_argment("model_folder", type=str, help="input ckpt model dir")
	aggs = parser.parse.parse_args()
	saved_model_graph(aggs.model_folders)