1. 程式人生 > 其它 >RuntimeError: Trying to backward through the graph a second time

RuntimeError: Trying to backward through the graph a second time

起因是把別人的用clip做分割的模型加到自己的框架上,結果報這個錯。Google了一下,發現可能是如下幾種原因:多個loss都要backward卻沒有retain graphhttps://www.zhihu.com/question/414980879,或者是rnn時對於前一次的輸出沒有detach就送進網路等等,還有一些奇怪的原因比如https://www.zhihu.com/search?type=content&q=RuntimeError%3A%20Trying%20to%20backward%20through%20the,結果發現和自己的情況都不符合。後來看到某CSDN的一個帖子https://blog.csdn.net/qq_49030008/article/details/125440817,雖然和自己的情況也不太一樣,但提到的預訓練模型啟發了我:現在跑的不就是clip做分割的任務嗎!於是開始一通亂改,比如把embedding後的text feature給detach或者儲存在迴圈外,每次forward的時候傳進來,等等,結果都不work。無奈之下只好跑起來官方程式碼對拍,但官方程式碼用了Pytorch lighting,封裝了不少東西,其餘的地方看起來貌似都沒啥特別的...最後在某一次除錯的時候列印了一下text feature的is_leaf和requires_grad屬性,發現兩輪後這兩個屬性竟然會發生反轉!仔細一看發現前兩輪不是真正在train,而是進行了一次驗證(可以自行查閱lighting框架的num_sanity_val_steps引數),猜想可能是在跑這兩次測試的過程中對模型引數屬性進行了一些奇妙的初始化,於是檢視框架原始碼https://github.com/Lightning-AI/lightning/blob/master/src/pytorch_lightning/trainer/trainer.py:

with torch.no_grad():
	val_loop.run()

發現其實就是跑了一下val的loop,於是對自己的程式碼在訓練前加上一部分:

with torch.no_grad():
	for cur_step, (images, labels) in enumerate(train_loader):
    images = images.to(device, dtype=torch.float32)
    outputs = model(images, labelset='')
    break

果然不報錯了,但原理是什麼還未搞懂。已經在 GitHub提了一個issue,希望能看到作者給的答案https://github.com/isl-org/lang-seg/issues/38