在一個專案中匯入多個不同tensorflow模型
阿新 • • 發佈:2019-01-24
剛開始直接採用呼叫一個模型的方法:
(1)定義網路
(2)新建sess:sess = tf.Session(config=config)
(3)定義saver:saver = tf.train.Saver()
(4)匯入權重:saver.restore(sess, xxx)
但是,如果在一個專案中同時匯入多個模型,會報錯,應該是graph衝突,所以需要給每個模型單獨新建graph:
g1 = tf.Graph()
isess = tf.Session(graph=g1)
with g1.as_default():
(定義網路模型結構)
isess.run(tf.global _variables_initializer())
saver = tf.train.Saver()
saver.restore(isess, xxx)#xxx為ckpt路徑
g2 = tf.Graph()
isess2 = tf.Session(graph=g2)
with g2.as_default():
(定義網路模型結構)
isess2.run(tf.global_variables_initializer())
saver2 = tf.train.Saver()
saver2.restore(isess2, xxx)
g3...
...