model.clpt轉換為tensorflow serving需要的saved_model.pb
阿新 • • 發佈:2018-12-14
在向tensorflow serving部署模型的時候需要pb格式的模型檔案,但是之前訓練用的是object detection api,訓練生成的是三個ckpt檔案,然後網上和官方轉換都用的是freeze_graph來讀取ckpt,將圖和引數凍結在一個frozen.pb檔案中,這個是不能直接部署到tensorflow serving上的,所以需要用saved_model將ckpt轉換到saved_model.pb,程式碼如下:
import tensorflow as tf import os.path from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import tag_constants MODEL_DIR = "/home/lyf/models/research/object_detection/faster-rcnn-resnet/model/save" MODEL_NAME = "saved_model.pb" ''' ---saved_model儲存三個步驟--- builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir) builder.add_meta_graph_and_variables(sess, ['tag_string']) builder.save() ---載入--- meta_graph_def = tf.saved_model.loader.load(sess, ['tag_string'], saved_model_dir) ''' def saved_model_graph(model_folders): input_checkpoint = model_folders + '/model.ckpt'#獲得ckpt檔案路徑 output_graph = os.path.join(MODEL_DIR,MODEL_NAME) saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True) graph = tf.get_default_graph()#獲得預設圖 input_graph_def = graph.as_graph_def() #返回一個序列化的圖代表當前的圖 builder = tf.saved_model.builder.SavedModelBuilder(MODEL_DIR) sigs = {} with tf.Session() as sess: saver.restore(sess, input_checkpoint)#恢復圖並得到資料 builder.add_meta_graph_and_variables(sess,\ [tag_constants.SERVING],\ signature_def_map=sigs) #第一個引數傳入當前的session,包含了graph的結構與所有變數。 #第二個引數是給當前需要儲存的meta_graph一個標籤,標籤名可以自定義,在之後載入模型的時候,需要根據這個標籤名去查詢對應的MetaGraphDef #標籤也可以選用系統定義好的引數,如tf.saved_model.tag_constants.SERVING與tf.saved_model.tag_constants.TRAINING。 #SignatureDef定義了一些協議,對所需的資訊進行封裝. builder.save() if __name__ == '__main__': parser = argparse.ArugumentParser() parser.add_argment("model_folder", type=str, help="input ckpt model dir") aggs = parser.parse.parse_args() saved_model_graph(aggs.model_folders)