1. 程式人生 > 實用技巧 >RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

目錄

問題

在用pytorch生成對抗網路的時候,出現錯誤Runtime Error: one of the variables needed for gradient computation has been modified by an inplace operation,特記錄排坑記錄。

環境配置

windows10 2004
python 3.7.4
pytorch 1.7.0 + cpu

解決過程

  • 嘗試一

這段錯誤程式碼看上去不難理解,意思為:計算梯度所需的某變數已被一就地操作修改。什麼是就地操作呢,舉個例子如x += 1就是典型的就地操作,可將其改為y = x + 1

。但很遺憾,這樣並沒有解決我的問題,這種方法的介紹如下。
在網上搜了很多相關部落格,大多原因如下:

由於0.4.0把Varible和Tensor融合為一個Tensor,inplace操作,之前對Varible能用,但現在對Tensor,就會出錯了。

所以解決方案很簡單:將所有inplace操作轉換為非inplace操作。如將x += 1換為y = x + 1
仍然有一個問題,即如何找到inplace操作,這裡提供一個小trick:分階段呼叫y.backward(),若報錯,則說明這之前有問題;反之則說明錯誤在該行之後。

  • 嘗試二

在我的程式碼里根本就沒有找到任何inplace操作,因此上面這種方法行不通。自己盯著程式碼,debug,啥也看不出來,好久......
忽然有了新idea。我的訓練階段的程式碼如下:

for epoch in range(1, epochs + 1):
    for idx, (lr, hr) in enumerate(traindata_loader):
        lrs = lr.to(device)
        hrs = hr.to(device)

        # update the discriminator
        netD.zero_grad()
        logits_fake = netD(netG(lrs).detach())
        logits_real = netD(hrs)
        # Label smoothing
        real = (torch.rand(logits_real.size()) * 0.25 + 0.85).clone().detach().to(device)
        fake = (torch.rand(logits_fake.size()) * 0.15).clone().detach().to(device)
        d_loss = bce(logits_real, real) + bce(logits_fake, fake)
        d_loss.backward(retain_graph=True)
        optimizerD.step()

        # update the generator
        netG.zero_grad()
        # !!!問題出錯行
        g_loss = contentLoss(netG(lrs), hrs) + adversarialLoss(logits_fake)
        g_loss.backward()        
        optimizerG.step()

判別器loss的backward是正常的,生成器loss的backward有問題。觀察到g_loss由兩項組成,所以很自然的想法就是刪掉其中一項看是否正常。結果為:只保留第一項程式正常執行;g_loss中包含第二項程式就出錯。
因此去看了adversarialLoss的程式碼:

class AdversarialLoss(nn.Module):
    def __init__(self):
        super(AdversarialLoss, self).__init__()
        self.bec_loss = nn.BCELoss()

    def forward(self, logits_fake):
        # Adversarial Loss
        # !!! 問題在這,logits_fake加上detach後就可以正常執行
        adversarial_loss = self.bec_loss(logits_fake, torch.ones_like(logits_fake))
        return 0.001 * adversarial_loss

看不出來任何問題,只能挨個試。這裡只有兩個變數:logits_faketorch.ones_like(logits_fake)。後者為常量,所以試著固定logits_fake,不讓其參與訓練,程式竟能運行了!

class AdversarialLoss(nn.Module):
    def __init__(self):
        super(AdversarialLoss, self).__init__()
        self.bec_loss = nn.BCELoss()

    def forward(self, logits_fake):
        # Adversarial Loss
        # !!! 問題在這,logits_fake加上detach後就可以正常執行
        adversarial_loss = self.bec_loss(logits_fake.detach(), torch.ones_like(logits_fake))
        return 0.001 * adversarial_loss

由此知道了被修改的變數是logits_fake。儘管程式可以運行了,但這樣做不一定合理。類AdversarialLoss中沒有對logits_fake進行修改,所以返回剛才的訓練程式中。

for epoch in range(1, epochs + 1):
    for idx, (lr, hr) in enumerate(traindata_loader):
        lrs = lr.to(device)
        hrs = hr.to(device)

        # update the discriminator
        netD.zero_grad()
        logits_fake = netD(netG(lrs).detach())
        logits_real = netD(hrs)
        # Label smoothing
        real = (torch.rand(logits_real.size()) * 0.25 + 0.85).clone().detach().to(device)
        fake = (torch.rand(logits_fake.size()) * 0.15).clone().detach().to(device)
        d_loss = bce(logits_real, real) + bce(logits_fake, fake)
        d_loss.backward(retain_graph=True)
        # 這裡進行的更新操作
        optimizerD.step()

        # update the generator
        netG.zero_grad()
        # !!!問題出錯行
        g_loss = contentLoss(netG(lrs), hrs) + adversarialLoss(logits_fake)
        g_loss.backward()        
        optimizerG.step()

注意到Discriminator在出錯行之前進行了更新操作,因此真相呼之欲出————optimizerD.step()logits_fake進行了修改。直接將其挪到倒數第二行即可,修改後程式碼為:

for epoch in range(1, epochs + 1):
    for idx, (lr, hr) in enumerate(traindata_loader):
        lrs = lr.to(device)
        hrs = hr.to(device)

        # update the discriminator
        netD.zero_grad()
        logits_fake = netD(netG(lrs).detach())
        logits_real = netD(hrs)
        # Label smoothing
        real = (torch.rand(logits_real.size()) * 0.25 + 0.85).clone().detach().to(device)
        fake = (torch.rand(logits_fake.size()) * 0.15).clone().detach().to(device)
        d_loss = bce(logits_real, real) + bce(logits_fake, fake)
        d_loss.backward(retain_graph=True)
        

        # update the generator
        netG.zero_grad()
        g_loss = contentLoss(netG(lrs), hrs) + adversarialLoss(logits_fake)
        g_loss.backward()   
        optimizerD.step()     
        optimizerG.step()

程式終於正常運行了,耶( •̀ ω •́ )y!

總結

原因:在計算生成器網路梯度之前先對判別器進行更新,修改了某些值,導致Generator網路的梯度計算失敗。
解決方法:將Discriminator的更新步驟放到Generator的梯度計算步驟後面。