1. 程式人生 > >PyTorch —— LeNet實現中的bug以及由此的小想法

PyTorch —— LeNet實現中的bug以及由此的小想法

經典的LeNet,在PyTorch/examples/mnist 實現中有個小問題,在這裡和大家分享一下。

是一個計算generalization error的問題。

計算generalization error時,原始碼有一行我一直不理解,test_loss /= len(test_loader) # loss function already averages over batch size。試了一下輸出,就發現len(test_loader)指的是進行一次data pass,會將所有的資料分割成多少份的mini-batch。然後配合原來的一行程式碼,在每次mini-batch計算loss的時候,test_loss += F.nll_loss(output, target).data[0]

,就可以推出原始碼在求解generalization error的邏輯是這樣的:

  1. 將一個data pass分成幾個mini-batch
  2. 每一個mini-batch,F.nll_loss(output, target).data[0]的loss value並不是整個mini-batch的loss,而是average loss,有一個預設size_average=True的引數(後面會用到)
  3. 進行一次data pass之後,就可以將每一個mini-batch average loss求和
  4. 將loss之和再除以mini-batch的數量,就得到最後的data point average loss

所以到這裡就能知道,這裡有一個隱藏的bug:這裡假設了我每一個mini-batch size是一樣的,所以才能用這樣求平均的方式。但實際上,最後一個mini-batch是很難正好“滿上”的。

更為精確求解loss的方法是,每一個mini-batch loss不算平均,而直接求和。最後除以所有data point的個數。大概程式碼如下:

for each mini-batch:
    ...
    test_loss += F.nll_loss(output, target, size_average=False).data[0]
    ...

...
test_loss /= len(test_loader.dataset)
...

第一次給pub repo做pull request,感覺挺不錯。PyTorch還有很多坑沒填,大家感興趣的可以慢慢填。我看到很多人直接把原bug照搬到自己的repo,也不知道有沒有發現這個問題。其實只要深究test_loss /= len(test_loader)這一行就能發現。

附:因為我pull request已經被merge到code base裡,感興趣的朋友可以移步到github. 類似的問題在mnist_hogwild也有,成功merge github.