在tensorflow實現直接讀取網路的引數(weight and bias)的值
阿新 • • 發佈:2020-06-26
訓練好了一個網路,想要檢視網路裡面引數是否經過BP演算法優化過,可以直接讀取網路裡面的引數,如果一直是隨機初始化的值,則證明訓練程式碼有問題,需要改。
下面介紹如何直接讀取網路的weight 和 bias。
(1) 獲取引數的變數名。可以使用一下函式獲取變數名:
def vars_generate1(self,scope_name_var): return [var for var in tf.global_variables() if scope_name_var in var.name ]
輸入你想要讀取的變數的一部分的名稱(scope_name_var),然後通過這個函式返回一個List,裡面是所有含有這個名稱的變數。
(2) 利用session讀取變數的值:
def get_weight(self): full_connect_variable = self.vars_generate1("pred_network/full_connect/l5_conv") with tf.Session() as sess: sess.run(tf.global_variables_initializer()) ##一定要先初始化變數 print(sess.run(full_connect_variable[0]))
之後如果想要看引數隨著訓練的變化,你可以將這些引數儲存到一個txt檔案裡面檢視。
補充知識: 如何在 PyTorch 中設定學習率衰減(learning rate decay)
很多時候我們要對學習率(learning rate)進行衰減,下面的程式碼示範瞭如何每30個epoch按10%的速率衰減:
def adjust_learning_rate(optimizer,epoch): """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" lr = args.lr * (0.1 ** (epoch // 30)) for param_group in optimizer.param_groups: param_group['lr'] = lr
什麼是param_groups?
optimizer通過param_group來管理引數組.param_group中儲存了引數組及其對應的學習率,動量等等.所以我們可以通過更改param_group[‘lr']的值來更改對應引數組的學習率。
# 有兩個`param_group`即,len(optim.param_groups)==2 optim.SGD([ {'params': model.base.parameters()},{'params': model.classifier.parameters(),'lr': 1e-3} ],lr=1e-2,momentum=0.9) #一個引數組 optim.SGD(model.parameters(),momentum=.9)
以上這篇在tensorflow實現直接讀取網路的引數(weight and bias)的值就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。