1. 程式人生 > >在一個專案中匯入多個不同tensorflow模型

在一個專案中匯入多個不同tensorflow模型

剛開始直接採用呼叫一個模型的方法:
(1)定義網路
(2)新建sess:sess = tf.Session(config=config)
(3)定義saver:saver = tf.train.Saver()
(4)匯入權重:saver.restore(sess, xxx)
但是,如果在一個專案中同時匯入多個模型,會報錯,應該是graph衝突,所以需要給每個模型單獨新建graph:

g1 = tf.Graph()
isess = tf.Session(graph=g1)
with g1.as_default():
    (定義網路模型結構)
    isess.run(tf.global
_variables_initializer()) saver = tf.train.Saver() saver.restore(isess, xxx)#xxx為ckpt路徑 g2 = tf.Graph() isess2 = tf.Session(graph=g2) with g2.as_default(): (定義網路模型結構) isess2.run(tf.global_variables_initializer()) saver2 = tf.train.Saver() saver2.restore(isess2, xxx) g3... ...