A pitfall with PyTorch's nn.functional.dropout
阿新 • Published: 2018-12-30
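I just fell into this one and nearly cried TT. I had clearly added a hundred dropouts, so why didn't the results change at all?
When using F.dropout (nn.functional.dropout), you have to set its training argument yourself so that it matches the training state of the model as a whole.
For example: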
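import torch.nn as nn
import torch.nn.functional as F

class DropoutFC(nn.Module):
    def __init__(self):
        super(DropoutFC, self).__init__()
        self.fc = nn.Linear(100, 20)

    def forward(self, input):
        out = self.fc(input)
        # Pitfall: `training` is not passed, so this call does not follow Net.train() / Net.eval()
        out = F.dropout(out, p=0.5)
        return out

Net = DropoutFC()
Net.train()
# train the Net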
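In the PyTorch version this post was written against, the F.dropout call in this code has no effect at all, because its training argument always stays at its default value, False. F.dropout is just an external function called from forward, so changing the training state of the model as a whole does not change the training argument of F.dropout. Here, out = F.dropout(out) is therefore simply out = out. Note that current PyTorch releases default to training=True instead, which reverses the symptom: dropout stays active even after Net.eval(). In both cases the cause is the same: the function never sees the model's training state. Ref: https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L535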
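The effect of the training argument can be checked directly. The following is a minimal sketch (the variable names are illustrative and not from the original post) that passes training explicitly in both calls, so it does not rely on the default of any particular PyTorch version.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)  # only to make the random mask repeatable
x = torch.ones(1000)

# training=False: the call is an identity, whatever p is.
off = F.dropout(x, p=0.5, training=False)
identity_when_off = torch.equal(off, x)

# training=True: each entry is zeroed with probability p, and the
# survivors are scaled by 1 / (1 - p) = 2 to keep the expected value.
on = F.dropout(x, p=0.5, training=True)
num_zeroed = int((on == 0).sum())
survivor_values = sorted(set(on[on != 0].tolist()))

print(identity_when_off)  # True
print(num_zeroed)         # roughly half of the 1000 entries
print(survivor_values)    # [2.0]
```

Since the function has no link to any module, the caller alone decides which of these two behaviours it gets.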
The correct usage is shown below: pass the model's own training state into the dropout function.
import torch.nn as nn
import torch.nn.functional as F

class DropoutFC(nn.Module):
    def __init__(self):
        super(DropoutFC, self).__init__()
        self.fc = nn.Linear(100, 20)

    def forward(self, input):
        out = self.fc(input)
        # self.training is toggled by Net.train() / Net.eval()
        out = F.dropout(out, p=0.5, training=self.training)
        return out

Net = DropoutFC()
Net.train()
# train the Net
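One way to convince yourself that this version follows the model's mode is to run the same input through it after eval() and after train(). This is a rough check, not part of the original post; the batch size of 8 and the names net and x are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DropoutFC(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(100, 20)

    def forward(self, x):
        # Forward the module's own flag to the functional call.
        return F.dropout(self.fc(x), p=0.5, training=self.training)

net = DropoutFC()
x = torch.randn(8, 100)  # hypothetical batch of 8 samples

net.eval()  # sets net.training = False
with torch.no_grad():
    # Dropout is disabled, so the output is just the linear layer's output.
    eval_matches_fc = torch.equal(net(x), net.fc(x))

net.train()  # sets net.training = True
with torch.no_grad():
    train_out = net(x)
    # About half of the 8 * 20 activations should now be exactly zero.
    train_has_zeros = bool((train_out == 0).any())

print(eval_matches_fc)   # True
print(train_has_zeros)   # True
```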
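Alternatively, use the nn.Dropout module. It is a thin wrapper around F.dropout that passes its own self.training, which Net.train() and Net.eval() update automatically:
import torch.nn as nn

class DropoutFC(nn.Module):
    def __init__(self):
        super(DropoutFC, self).__init__()
        self.fc = nn.Linear(100, 20)
        # Registered as a submodule, so it follows Net.train() / Net.eval()
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, input):
        out = self.fc(input)
        out = self.dropout(out)
        return out

Net = DropoutFC()
Net.train()
# train the Net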
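The module form works because train() and eval() set the training flag on every registered submodule, and nn.Dropout hands that flag to F.dropout. The sketch below illustrates both points under a fixed seed; the nn.Sequential model and the variable names are assumptions made for the example, not code from the original post.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# The mode of the parent propagates to child modules.
model = nn.Sequential(nn.Linear(100, 20), nn.Dropout(p=0.5))
model.eval()
flag_after_eval = model[1].training    # False
model.train()
flag_after_train = model[1].training   # True

# With the same RNG state, the module and the functional call
# (given the module's flag) draw the same mask.
drop = nn.Dropout(p=0.5)
x = torch.ones(4, 5)

drop.train()
torch.manual_seed(0)
module_out = drop(x)
torch.manual_seed(0)
functional_out = F.dropout(x, p=0.5, training=drop.training)
same_in_train = torch.equal(module_out, functional_out)

drop.eval()
identity_in_eval = torch.equal(drop(x), x)

print(flag_after_eval, flag_after_train)  # False True
print(same_in_train, identity_in_eval)    # True True
```

Which form to use is mostly a matter of taste: the module shows up in print(model) and needs no extra argument, while the functional form keeps forward compact as long as training=self.training is never forgotten.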