1. 程式人生 > 程式設計 >pytorch對梯度進行視覺化進行梯度檢查教程

pytorch對梯度進行視覺化進行梯度檢查教程

目的: 在訓練神經網路的時候,有時候需要自己寫操作,比如faster_rcnn中的roi_pooling,我們可以視覺化前向傳播的影象和反向傳播的梯度影象,前向傳播可以檢查流程和計算的正確性,而反向傳播則可以大概檢查流程的正確性。

實驗

視覺化rroi_align的梯度

1.pytorch 0.4.1及之前,需要宣告需要引數,這裡將圖片資料宣告為variable

im_data = Variable(im_data,requires_grad=True)

2.進行前向傳播,最後的loss對映為一個一維的張量

pooled_feat = roipool(im_data,rois.view(-1,6))
res = pooled_feat.pow(2).sum()
res.backward()

3.注意求loss的時候採用更加複雜,或者更多的運算(這樣在梯度視覺化的時候效果才更加明顯)

視覺化效果

原始圖片

pytorch對梯度進行視覺化進行梯度檢查教程

梯度視覺化圖片

pytorch對梯度進行視覺化進行梯度檢查教程

原圖+梯度圖

pytorch對梯度進行視覺化進行梯度檢查教程

小結:

可以看到誤差梯度的位置是正確的,誤差是否正確,需要其他方式驗證(暫時沒有思路)

可以看到上面在求loss的時候為:loss = sum(x2),但是如果換成:loss = mean(x),效果就沒有上面明顯。

實驗二的效果

pytorch對梯度進行視覺化進行梯度檢查教程

loss = mean(x)

可以看到根本無法看到誤差梯度的位置資訊

實驗三:loss = sum(x)

pytorch對梯度進行視覺化進行梯度檢查教程

pytorch對梯度進行視覺化進行梯度檢查教程

小結: 可以看到位置資訊有差別,比如國徽部分,這會讓人以為,國徽部分只利用了左部分的資訊,或者自己手寫的操作誤差索引不對。

可以通過兩種方式進行驗證

1.用更多,更復雜的運算求loss,比如pow,等

2.用matplotlib顯示圖片後,用滑鼠可以指示每個點的具體的值,可以檢測有誤差梯度區域是否和無誤差梯度區域有差別。

以上這篇pytorch對梯度進行視覺化進行梯度檢查教程就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。