PyTorch自定義網路層
PyTorch自定義網路層
這篇部落格關於如何通過自定義function和module來擴充套件 torch.nn和 torch.autograd,module對應於網路中的層。對於淺層網路,我們可以手寫前向傳播和後向傳播過程。但是在深度學習中,網路結構很複雜,前向和後向也很複雜,手寫就變得很困難。幸運的是,在PyTorch中存在自動微分的包,可以用來解決該問題。在使用自動微分包的時候,網路前向傳播會定義一個計算圖(computational graph),圖中的節點是Tensor,兩個節點之間的邊對應了Tensor之間的變換關係函式。有了計算圖的存在,Tensor的梯度計算也變得容易了些。例如,x是一個Tensor,其屬性x.requires_grad = True,那麼x.grad就是一個儲存了這個Tensor x的梯度的一個標量值。
擴充套件 torch.autograd
最基礎的自動求導操作在底層就是作用在兩個張量上。前向傳播函式是從輸入張量到輸出張量的計算過程;反向傳播是輸入輸出張量的梯度(一些標量)並輸出輸入張量的梯度(一些標量)。在pytorch中我們可以很容易地定義自己的自動求導操作,通過繼承torch.autograd.Function並定義forward和backward函式。
往autograd中新增操作需要給每一個操作實現一個新的Function子類。Function是autograd用來計算結果和梯度,儲存操作記錄的。每一個新的function都需要你去實現以下兩個方法:
- forward():前向傳播操作,它可以輸入任意多的引數。任意的python物件都可以。在呼叫之前,記錄了梯度的Tensor引數(如 requires_grad=True)會被轉為不去記錄梯度的引數。你既可以返回單一的Tensor,也可以返回Tensor的元組。
- backward():反向傳播(梯度公式),輸出的引數個數應與輸入的引數個數一樣多,每一個輸出的引數代表著一個輸入引數的梯度。如果你輸入的引數不需要梯度(needs_input_grad是一個布林值元組,表示是否需要梯度計算),或者輸入不是Tensor物件,可以返回None.
下面是torch.nn裡面的一個Linear function的程式碼:
# Inherit from Function
class LinearFunction(Function):
# Note that both forward and backward are @staticmethods
@staticmethod
# bias is an optional argument
def forward(ctx, input, weight, bias=None):
# ctx在這裡類似self, ctx的屬性可以在backward中呼叫
ctx.save_for_backward(input, weight, bias)
output = input.mm(weight.t())
if bias is not None:
output += bias.unsqueeze(0).expand_as(output)
return output
# This function has only a single output, so it gets only one gradient
@staticmethod
def backward(ctx, grad_output):
# This is a pattern that is very convenient - at the top of backward
# unpack saved_tensors and initialize all gradients w.r.t. inputs to
# None. Thanks to the fact that additional trailing Nones are
# ignored, the return statement is simple even when the function has
# optional inputs.
# 在這裡,獲取ctx中儲存的引數
input, weight, bias = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
# These needs_input_grad checks are optional and there only to
# improve efficiency. If you want to make your code simpler, you can
# skip them. Returning gradients for inputs that don't require it is
# not an error.
if ctx.needs_input_grad[0]:
grad_input = grad_output.mm(weight)
if ctx.needs_input_grad[1]:
grad_weight = grad_output.t().mm(input)
if bias is not None and ctx.needs_input_grad[2]:
grad_bias = grad_output.sum(0).squeeze(0)
return grad_input, grad_weight, grad_bias
為了使用簡單些,我們建議給apply 方法換個名字:
linear = LinearFunction.apply
這裡,給了另一個例子,引數不是Tensor,而是常量,不需要更新:
class MulConstant(Function):
@staticmethod
def forward(ctx, tensor, constant):
# ctx is a context object that can be used to stash information
# for backward computation
ctx.constant = constant
return tensor * constant
@staticmethod
def backward(ctx, grad_output):
# We return as many input gradients as there were arguments.
# Gradients of non-Tensor arguments to forward must be None.
return grad_output * ctx.constant, None
你可能想檢查下backward方法計算function的梯度是否準確,你可以和數值近似方法得到的結果來比較:
from torch.autograd import gradcheck
# gradcheck takes a tuple of tensors as input, check if your gradient
# evaluated with these tensors are close enough to numerical
# approximations and returns True if they all verify this condition.
input = (torch.randn(20,20,dtype=torch.double,requires_grad=True), torch.randn(30,20,dtype=torch.double,requires_grad=True))
test = gradcheck(linear, input, eps=1e-6, atol=1e-4)
print(test)
擴充套件torch.nn
計算圖和自動微分在定義複雜網路和求梯度的時候非常好用,但對於大型網路,這個有點太底層。在我們構建網路的時候,經常希望將計算限制在每個層裡面(引數分層更新)。在PyTorch中提供了nn包,定義了一組等價於layer的模組module。一個module接受輸入tensor並得到輸出tensor,同時也會包含可學習的引數。
有時候我們需要運用一些新的且nn中沒有的module,此時就需要自定義自己的module了,自定義的module需要繼承nn.Module且自定義forward函式。其中forward函式可以接受輸入tensor並利用其他模型或者其他自動給求導操作來產生輸出tensor. 但並不需要重寫backward函式,因此nn使用了autograd. 這也就意味著,需要自定義module, 都必須要有對應的autograd函式呼叫其中的backward。
增加一個Module
因為nn經常使用autograd,增加一個新的Module需要實現一個計算梯度的function。例如,我們想實現一個Linear module,而且我們已經在上面實現了它的function。我們要實現兩個function:
- init(optional):需要如kernel_size, 特徵個數等引數,而且初始化引數和buffer。
- forward():例項化一個function並使用它來執行操作。這跟上面的function wrapper很類似。
下面就是Linear module如何實現的:
class Linear(nn.Module):
def __init__(self, input_features, output_features, bias=True):
super(Linear, self).__init__()
self.input_features = input_features
self.output_features = output_features
# nn.Parameter is a special kind of Tensor, that will get
# automatically registered as Module's parameter once it's assigned
# as an attribute. Parameters and buffers need to be registered, or
# they won't appear in .parameters() (doesn't apply to buffers), and
# won't be converted when e.g. .cuda() is called. You can use
# .register_buffer() to register buffers.
# nn.Parameters require gradients by default.
self.weight = nn.Parameter(torch.Tensor(output_features, input_features))
if bias:
self.bias = nn.Parameter(torch.Tensor(output_features))
else:
# You should always register all possible parameters, but the
# optional ones can be None if you want.
self.register_parameter('bias', None)
# Not a very smart way to initialize weights
self.weight.data.uniform_(-0.1, 0.1)
if bias is not None:
self.bias.data.uniform_(-0.1, 0.1)
def forward(self, input):
# See the autograd section for explanation of what happens here.
return LinearFunction.apply(input, self.weight, self.bias)
def extra_repr(self):
# (Optional)Set the extra information about this module. You can test
# it by printing an object of this class.
return 'in_features={}, out_features={}, bias={}'.format(
self.in_features, self.out_features, self.bias is not None
)
Function與Module異同:
Function與Module都可以對pytorch進行自定義拓展,使其滿足網路的需求,但這兩者還是有十分重要的不同:
- Function一般只定義一個操作,因為其無法儲存引數,因此適用於啟用函式、pooling等操作;Module是儲存了引數,因此適合於定義一層,如線性層,卷積層,也適用於定義一個網路;
- Function需要定義三個方法:init, forward, backward(需要自己寫求導公式);Module:只需定義init和forward,而backward的計算由自動求導機制構成;
- 可以不嚴謹的認為,Module是由一系列Function組成,因此其在forward的過程中,Function和Variable組成了計算圖,在backward時,只需呼叫Function的backward就得到結果,因此Module不需要再定義backward;
- Module不僅包括了Function,還包括了對應的引數,以及其他函式與變數,這是Function所不具備的;
- Module 是PyTorch組織神經網路的基本方式。Module 包含了模型的引數以及計算邏輯。Function承載了實際的功能,定義了前向和後向的計算邏輯;
- Module 是任何神經網路的基類,PyTorch中所有模型都必需是 Module的子類。Module 可以套嵌,構成樹狀結構。一個 Module 可以通過將其他 Module 做為屬性的方式,完成套嵌。
- Function 是 PyTorch自動求導機制的核心類。Function是無引數或者說無狀態的,它只負責接收輸入,返回相應的輸出;對於反向,它接收輸出相應的梯度,返回輸入相應的梯度。
- 在呼叫loss.backward()時,使用的是Function子類中定義的backward()函式。