1. 程式人生 > 程式設計 >pytorch 模型的train模式與eval模式例項

pytorch 模型的train模式與eval模式例項

原因

對於一些含有batch normalization或者是Dropout層的模型來說,訓練時的froward和驗證時的forward有計算上是不同的,因此在前向傳遞過程中需要指定模型是在訓練還是在驗證。

原始碼

[docs] def train(self,mode=True):
  r"""Sets the module in training mode.

  This has any effect only on certain modules. See documentations of
  particular modules for details of their behaviors in training/evaluation
  mode,if they are affected,e.g. :class:`Dropout`,:class:`BatchNorm`,etc.

  Returns:
   Module: self
  """
  self.training = mode
  for module in self.children():
   module.train(mode)
  return self

[docs] def eval(self):
  r"""Sets the module in evaluation mode.

  This has any effect only on certain modules. See documentations of
  particular modules for details of their behaviors in training/evaluation
  mode,etc.
  """
  #該方法呼叫了nn.train()方法,把引數預設值改為false. 增加聚合性
  return self.train(False)

在使用含有BN層,dropout層的神經網路來說,必須要區分訓練驗證

以上這篇pytorch 模型的train模式與eval模式例項就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。