How to train only some of the variables in TensorFlow
阿新 • Published: 2019-01-06
After looking at the documentation and the code, I was not able to find a way to remove a Variable from the TRAINABLE_VARIABLES collection.
Here is what happens:
- The first time you call tf.get_variable('weights', trainable=True), the variable is created and added to the list of TRAINABLE_VARIABLES.
- The second time you call tf.get_variable('weights', trainable=False) (inside a variable scope with reuse enabled), you get the same variable back. The argument trainable=False has no effect, because the variable is already present in the list of TRAINABLE_VARIABLES.
First solution
When calling the minimize method of the optimizer (see the docs), you can pass a var_list=[...] argument that contains only the variables you want to optimize.
For instance, if you want to freeze all the layers of VGG except the last two, you can pass only the weights of the last two layers in var_list.
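Below is a minimal sketch of the var_list approach. It assumes TensorFlow 2.x through tf.compat.v1. A tiny hypothetical two-layer model stands in for VGG: the scope "frozen_layer" plays the role of the early layers and "last_layer" the layers that should keep training. All of these names are illustrative.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # use the TF 1.x graph-mode API

x = tf.placeholder(tf.float32, shape=[None, 4])
y = tf.placeholder(tf.float32, shape=[None, 1])

# Hypothetical stand-ins for "early layers" and "last layer".
with tf.variable_scope("frozen_layer"):
    w1 = tf.get_variable("w", shape=[4, 8])
    hidden = tf.tanh(tf.matmul(x, w1))
with tf.variable_scope("last_layer"):
    w2 = tf.get_variable("w", shape=[8, 1])
    pred = tf.matmul(hidden, w2)

loss = tf.reduce_mean(tf.square(pred - y))

# Collect only the variables under "last_layer" and hand them to minimize().
train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="last_layer")
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss, var_list=train_vars)

feed = {x: np.ones((3, 4), np.float32), y: np.ones((3, 1), np.float32)}
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    w1_before, w2_before = sess.run([w1, w2])
    sess.run(train_op, feed_dict=feed)
    w1_after, w2_after = sess.run([w1, w2])

print("frozen layer unchanged:", np.array_equal(w1_before, w1_after))
print("last layer updated:", not np.array_equal(w2_before, w2_after))
```

Gradients still flow through the frozen layer to compute the loss, but the optimizer only applies updates to the variables in var_list.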
Second solution
You can use a tf.train.Saver() to save variables and restore them later (see this tutorial).
- First, train your entire VGG model with all variables trainable, and save them to a checkpoint file by calling saver.save(sess, "/path/to/dir/model.ckpt").
- Then (in another file), train the second version, in which the layers you want to freeze are created with trainable=False. Load the previously stored variables with saver.restore(sess, "/path/to/dir/model.ckpt").
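Here is a minimal sketch of the two-phase save/restore workflow. It assumes TensorFlow 2.x through tf.compat.v1. Two separate tf.Graph objects in one script simulate "another file", and a temporary directory stands in for "/path/to/dir". The build_model helper and its scope names are hypothetical. Both phases must create variables with the same names so the checkpoint keys match.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # use the TF 1.x graph-mode API

# Temporary directory stands in for "/path/to/dir" (assumption).
ckpt_path = os.path.join(tempfile.mkdtemp(), "model.ckpt")


def build_model(base_trainable):
    """Hypothetical two-layer model; base_trainable toggles the first layer."""
    x = tf.placeholder(tf.float32, shape=[None, 4])
    y = tf.placeholder(tf.float32, shape=[None, 1])
    with tf.variable_scope("base"):
        w_base = tf.get_variable("w", shape=[4, 8], trainable=base_trainable)
    with tf.variable_scope("head"):
        w_head = tf.get_variable("w", shape=[8, 1])
    pred = tf.matmul(tf.tanh(tf.matmul(x, w_base)), w_head)
    loss = tf.reduce_mean(tf.square(pred - y))
    # With no var_list, minimize() updates every TRAINABLE_VARIABLES entry.
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
    return x, y, w_base, train_op


feed_x = np.ones((3, 4), np.float32)
feed_y = np.ones((3, 1), np.float32)

# Phase 1: train everything, then save a checkpoint.
with tf.Graph().as_default():
    x, y, w_base, train_op = build_model(base_trainable=True)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(train_op, feed_dict={x: feed_x, y: feed_y})
        saver.save(sess, ckpt_path)
        w_base_saved = sess.run(w_base)

# Phase 2 (normally a separate script): same names, base layer frozen.
with tf.Graph().as_default():
    x, y, w_base, train_op = build_model(base_trainable=False)
    saver = tf.train.Saver()  # non-trainable variables are still saved/restored
    with tf.Session() as sess:
        saver.restore(sess, ckpt_path)  # restore replaces the initializer run
        w_base_restored = sess.run(w_base)
        sess.run(train_op, feed_dict={x: feed_x, y: feed_y})
        w_base_after = sess.run(w_base)

print(np.array_equal(w_base_saved, w_base_restored))  # True: loaded from checkpoint
print(np.array_equal(w_base_restored, w_base_after))  # True: frozen layer not updated
```

Plain gradient descent is used because it creates no extra slot variables. An optimizer such as Adam would add variables that are missing from the phase-1 checkpoint, and the default Saver would then fail to restore them.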
Optionally, you can save only some of the variables in your checkpoint file by passing a var_list to tf.train.Saver. See the docs for more info.
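Here is a minimal sketch of saving only a subset of variables. It assumes TensorFlow 2.x through tf.compat.v1 and uses hypothetical scope names "base" and "head". A temporary file path stands in for a real checkpoint location.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # use the TF 1.x graph-mode API

ckpt_path = os.path.join(tempfile.mkdtemp(), "partial.ckpt")

with tf.Graph().as_default():
    with tf.variable_scope("base"):
        tf.get_variable("w", shape=[4, 8])
    with tf.variable_scope("head"):
        tf.get_variable("w", shape=[8, 1])
    # Only the variables under "base" are handed to this Saver.
    base_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="base")
    saver = tf.train.Saver(var_list=base_vars)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, ckpt_path)

# Inspect which variable names ended up in the checkpoint.
saved_names = [name for name, _ in tf.train.list_variables(ckpt_path)]
print(saved_names)  # should contain "base/w" but not "head/w"
```

The same var_list argument works on the restoring side. A Saver built with only the base variables restores just those variables and leaves the others to be initialized normally.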