1. 程式人生 > 程式設計 >在tensorflow實現直接讀取網路的引數(weight and bias)的值

在tensorflow實現直接讀取網路的引數(weight and bias)的值

訓練好了一個網路,想要檢視網路裡面引數是否經過BP演算法優化過,可以直接讀取網路裡面的引數,如果一直是隨機初始化的值,則證明訓練程式碼有問題,需要改。

下面介紹如何直接讀取網路的weight 和 bias。

(1) 獲取引數的變數名。可以使用一下函式獲取變數名:

def vars_generate1(self,scope_name_var): return [var for var in tf.global_variables() if scope_name_var in var.name ]

輸入你想要讀取的變數的一部分的名稱(scope_name_var),然後通過這個函式返回一個List,裡面是所有含有這個名稱的變數。

(2) 利用session讀取變數的值:

def get_weight(self):
 full_connect_variable = self.vars_generate1("pred_network/full_connect/l5_conv")
 with tf.Session() as sess:
  sess.run(tf.global_variables_initializer()) ##一定要先初始化變數
  print(sess.run(full_connect_variable[0]))

之後如果想要看引數隨著訓練的變化,你可以將這些引數儲存到一個txt檔案裡面檢視。

補充知識:

如何在 PyTorch 中設定學習率衰減(learning rate decay)

在tensorflow實現直接讀取網路的引數(weight and bias)的值

很多時候我們要對學習率(learning rate)進行衰減,下面的程式碼示範瞭如何每30個epoch按10%的速率衰減:

def adjust_learning_rate(optimizer,epoch):
 """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
 lr = args.lr * (0.1 ** (epoch // 30))
 for param_group in optimizer.param_groups:
  param_group['lr'] = lr

什麼是param_groups?

optimizer通過param_group來管理引數組.param_group中儲存了引數組及其對應的學習率,動量等等.所以我們可以通過更改param_group[‘lr']的值來更改對應引數組的學習率。

# 有兩個`param_group`即,len(optim.param_groups)==2
optim.SGD([
    {'params': model.base.parameters()},{'params': model.classifier.parameters(),'lr': 1e-3}
   ],lr=1e-2,momentum=0.9)
 
#一個引數組
optim.SGD(model.parameters(),momentum=.9)

以上這篇在tensorflow實現直接讀取網路的引數(weight and bias)的值就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。