1. 程式人生 > >loss function with value of NAN

loss function with value of NAN

根據網上的資料,可能的情況就是1. 梯度太大。2. 計算過程中可能出現了除零的出錯。

試過改變梯度無效後,確定問題出在其中一個自定義的loss函式,必須把這個函式的每一步計算是否導致零考察。

關於pytorch自動求導的基本介紹如下:

設計如下測試,

a = torch.ones(2, 2, requires_grad=True)
b = torch.ones(2, 2, requires_grad=True)

a = a + 0.001
#b = b + 0.002

x = (a - b).pow_(2).sum(1).sqrt_() # 該函式中的a如果等於b,會導致反向之後a和b的梯度為nan。


out = (x*x).mean()


out = out.backward()

print(a.grad)  # 檢視輸入張量的梯度

可以得出是pow和sum,sqrt中的一個函數出錯,不細究了。僅此記錄