1. 程式人生 > 程式設計 >Tensorflow 自定義loss的情況下初始化部分變數方式

Tensorflow 自定義loss的情況下初始化部分變數方式

一般情況下,tensorflow裡面變數初始化過程為:

  #variables ...........
  #..................... 
  init = tf.initialize_all_variables()
  sess.run(init)

這裡 tf.initialize_all_variables() 會初始化所有的變數。

實際過程中,假設有a,b,c三個變數,其中a已經被初始化了,只想單獨初始化b,c,那麼:

  #variables ...
  ...
  init = tf.variables_initializer([b,c])
  sess.run(init)

此外,如果自行修改了optimizer,如下程式碼就會報錯:

  #definition of variables a,c ...
  ....
  my_optimizer = tf.train.RMSProp(learning_rate = 0.1).minimize(my_cost)
  init = tf.variables_initializer([b,c])
  sess.run(init)

這是因為自己定義的optimizer會生成新的variables,但是在init裡面並沒有初始化,所以無法訪問,會報錯。解決方法如下:

  a = tf.Variables(...)      #line N
  temp = set(tf.all_variables()) 
  b = tf.Variables(...)
  c = tf.Variables(...) 
  #definition of my optimizer
  optimizer = tf.train.......
  init = tf.variables_initializer(set(tf.all_varialbles())-temp) # line M
  sess.run(init)

首先,temp = set(tf.all_variables()) 將該行(line N)程式碼之前的所有變數儲存在temp中,接下來定義變數b,c,以及自定義的optimizer,然後 set(tf.all_varialbles()儲存了改行(line M)之前的所有變數(包括optimizer生成的變數以及temp中所含的變數),set(tf.all_varialbles())-temp相減得到line N~M這幾行定義的變數。

以上這篇Tensorflow 自定義loss的情況下初始化部分變數方式就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。