1. 程式人生 > 實用技巧 >PyTorch之nn.Module類與前向傳播函式forward的理解

PyTorch之nn.Module類與前向傳播函式forward的理解

1.nn.Module類理解

pytorch裡面一切自定義操作基本上都是繼承nn.Module類來實現的

方法預覽:

class Module(object):
    def __init__(self):
    def forward(self, *input):
 
    def add_module(self, name, module):
    def cuda(self, device=None):
    def cpu(self):
    def __call__(self, *input, **kwargs):
    def parameters(self, recurse=True):
    
def named_parameters(self, prefix='', recurse=True): def children(self): def named_children(self): def modules(self): def named_modules(self, memo=None, prefix=''): def train(self, mode=True): def eval(self): def zero_grad(self): def __repr__(self): def __dir__(self):
''' 有一部分沒有完全列出來 '''

我們在定義自已的網路的時候,需要繼承nn.Module類,並重新實現建構函式__init__和forward這兩個方法。但有一些注意技巧:

(1)一般把網路中具有可學習引數的層(如全連線層、卷積層等)放在建構函式__init__()中,當然我也可以吧不具有引數的層也放在裡面;

(2)一般把不具有可學習引數的層(如ReLU、dropout、BatchNormanation層)可放在建構函式中,也可不放在建構函式中,如果不放在建構函式__init__裡面,則在forward方法裡面可以使用nn.functional來代替

(3)forward方法是必須要重寫的,它是實現模型的功能,實現各個層之間的連線關係的核心


總結:

torch.nn是專門為神經網路設計的模組化介面。nn構建於autograd之上,可以用來定義和執行神經網路。
nn.Module是nn中十分重要的類,包含網路各層的定義及forward方法
定義自已的網路:
  需要繼承nn.Module類,並實現forward方法。
  一般把網路中具有可學習引數的層放在建構函式__init__()中,
  不具有可學習引數的層(如ReLU)可放在建構函式中,也可不放在建構函式中(而在forward中使用nn.functional來代替)
  只要在nn.Module的子類中定義了forward函式,backward函式就會被自動實現(利用Autograd)
  在forward函式中可以使用任何Variable支援的函式,畢竟在整個pytorch構建的圖中,是Variable在流動。還可以使用if,for,print,log等python語法.
注:Pytorch基於nn.Module構建的模型中,只支援mini-batch的Variable輸入方式

2.forward()函式自動呼叫的理解和分析

最近在使用pytorch的時候,模型訓練時,不需要使用forward,只要在例項化一個物件中傳入對應的引數就可以自動呼叫 forward 函式

自動呼叫 forward 函式原因分析:

利用Python的語言特性,y = model(x)是呼叫了物件model的__call__方法,而nn.Module把__call__方法實現為類物件的forward函式,所以任意繼承了nn.Module的類物件都可以這樣簡寫來呼叫forward函式。

案例:

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
 
    layer1 = nn.Sequential()
    layer1.add_module('conv1', nn.Conv(1, 6, 3, padding=1))
    layer1.add_moudle('pool1', nn.MaxPool2d(2, 2))
    self.layer1 = layer1
 
    layer2 = nn.Sequential()
    layer2.add_module('conv2', nn.Conv(6, 16, 5))
    layer2.add_moudle('pool2', nn.MaxPool2d(2, 2))
    self.layer2 = layer2
 
    layer3 = nn.Sequential()
    layer3.add_module('fc1', nn.Linear(400, 120))
    layer3.add_moudle('fc2', nn.Linear(120, 84))
    layer3.add_moudle('fc3', nn.Linear(84, 10))
    self.layer3 = layer3
    

  def forward(self, x): x = self.layer1(x) x = self.layer2(x) x = x.view(x.size(0), -1) x = self.layer3(x) return x

模型呼叫:

model = LeNet()
y = model(x)

呼叫forward方法的具體流程是:

執行y = model(x)時,由於LeNet類繼承了Module類,而Module這個基類中定義了__call__方法,所以會執行__call__方法,而__call__方法中呼叫了forward()方法

只要定義型別的時候,實現__call__函式,這個型別就成為可呼叫的。 換句話說,我們可以把這個型別的物件當作函式來使用

定義__call__方法的類可以當作函式呼叫(參見:https://www.cnblogs.com/luckyplj/p/13378008.html)

    def __call__(self, *input, **kwargs):
        for hook in self._forward_pre_hooks.values():
            result = hook(self, input)
            if result is not None:
                if not isinstance(result, tuple):
                    result = (result,)
                input = result
        if torch._C._get_tracing_state():
            result = self._slow_forward(*input, **kwargs)
        else:
            result = self.forward(*input, **kwargs)
        for hook in self._forward_hooks.values():
            hook_result = hook(self, input, result)
            if hook_result is not None:
                result = hook_result
        if len(self._backward_hooks) > 0:
            var = result
            while not isinstance(var, torch.Tensor):
                if isinstance(var, dict):
                    var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
                else:
                    var = var[0]
            grad_fn = var.grad_fn
            if grad_fn is not None:
                for hook in self._backward_hooks.values():
                    wrapper = functools.partial(hook, self)
                    functools.update_wrapper(wrapper, hook)
                    grad_fn.register_hook(wrapper)
        return result

總結:當執行model(x)的時候,底層自動呼叫forward方法計算結果

參考文獻:
https://blog.csdn.net/u011501388/article/details/84062483