TensorFlow載入部分模型
阿新 • • 發佈:2021-12-30
詳情參考https://www.cnblogs.com/yibeimingyue/p/11921474.html
本文采用的方式為重寫一樣的graph, 然後恢復指定scope
1.儲存模型部分,通過saver引數,定義要儲存的scope:
variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
tf.get_collection能生成參與訓練的scope值的列表,可以直接作為Saver的引數,可以按需通過列表切片指定scope.
variables = variables[8:]
saver = tf.train.Saver(variables) # create a saver
saver.save(sess,saver_path)
2.重寫一樣的graph
class Rebuild(object): def __init__(self, batch_size): """ build the graph """
定義兩個saver
saver_vgg = tf.train.Saver(vgg_ref_vars) # 這個是要恢復部分的saver saver = tf.train.Saver() # 這個是當前新圖的saver
在例項化Rebuild類之後,初始化,然後restore
with tf.Session(config=config) as sess: sess.run(init) ... saver_vgg.restore(sess, vgg_graph_weight)#使用匯入圖的saver來恢復