1. 程式人生 > 其它 >TensorFlow載入部分模型

TensorFlow載入部分模型

詳情參考https://www.cnblogs.com/yibeimingyue/p/11921474.html

本文采用的方式為重寫一樣的graph, 然後恢復指定scope

1.儲存模型部分,通過saver引數,定義要儲存的scope:

variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
  tf.get_collection能生成參與訓練的scope值的列表,可以直接作為Saver的引數,可以按需通過列表切片指定scope.
variables = variables[8:]
 saver = tf.train.Saver(variables)  # create a saver
saver.save(sess,saver_path)

  2.重寫一樣的graph

class Rebuild(object):
    def __init__(self, batch_size):
        """
        build the graph
        """

  定義兩個saver

saver_vgg = tf.train.Saver(vgg_ref_vars) # 這個是要恢復部分的saver
saver = tf.train.Saver() # 這個是當前新圖的saver

  在例項化Rebuild類之後,初始化,然後restore

with tf.Session(config=config) as sess:
sess.run(init)
...
saver_vgg.restore(sess, vgg_graph_weight)#使用匯入圖的saver來恢復