1. 程式人生 > >Pytorch半精度浮點型網路訓練問題

Pytorch半精度浮點型網路訓練問題

用Pytorch1.0進行半精度浮點型網路訓練需要注意下問題:

1、網路要在GPU上跑,模型和輸入樣本資料都要cuda().half()

2、模型引數轉換為half型,不必索引到每層,直接model.cuda().half()即可

3、對於半精度模型,優化演算法,Adam我在使用過程中,在某些引數的梯度為0的時候,更新權重後,梯度為零的權重變成了NAN,這非常奇怪,但是Adam演算法對於全精度資料型別卻沒有這個問題。

  另外,SGD演算法對於半精度和全精度計算均沒有問題。

 

還有一個問題是不知道是不是網路結構比較小的原因,使用半精度的訓練速度還沒有全精度快。這個值得後續進一步探索。