pytorch學習筆記(十七):python 端擴充套件 pytorch
阿新 • • 發佈:2019-01-05
pytorch
雖然提供了很多的 op
使得我們很容易的使用。但是當已有的 op
無法滿足我們的要求的時候,那就需要自己動手來擴充套件。 pytorch
提供了兩種方式來擴充套件 pytorch
的基礎功能。
- 通過繼承
autograd.Function
- 通過
C
來擴充套件
本篇部落格主要介紹 繼承 autograd.Function
來擴充套件 pytorch
。
繼承 autograd.Function
的 子類 只需要 實現兩個 靜態方法:
forward
: 計算op
的前向過程.
- 在執行
forward
之前,Variable
引數已經被轉換成了Tensor
forward
的形參可以有預設引數,預設引數可以是任意python
物件。- 可以返回任意多個
Tensor
- 裡面可以使用任何
python
操作,但是return
的值必須是Tensor
!!!
- 在執行
backward
: 計算 梯度,
forward
返回幾個 值, 這裡就需要幾個 形參,還得外加一個ctx
。forward
有幾個 形參(不包含ctx
) ,backward
就得返回幾個值。bacward
實參也是Variable
。backward
返回的得是Variable
。
一個 Demo(來自官網)
class LinearFunction(Function) :
# forward 和 backward 都得是 靜態方法!!!!!
@staticmethod
# bias 是個可選引數,有個 預設值 None
def forward(ctx, input, weight, bias=None):
# input,weight 都已經變成了 Tensor
# 用 ctx 把該存的存起來,留著 backward 的時候用
ctx.save_for_backward(input, weight, bias)
output = input.mm(weight.t())
if bias is not None:
output += bias.unsqueeze(0).expand_as(output)
return output
# 由於 forward 只有一個 返回值,所以 backward 只需要一個引數 接收 梯度。
@staticmethod
def backward(ctx, grad_output):
# grad_output 是 Variable 型別。
# 在開頭的地方將儲存的 tensor 給 unpack 了
# 然後 給 所有應該返回的 梯度 以 None 初始化。
# saved_variables 返回的是 Variable!!! 不是 Tensor 了。
input, weight, bias = ctx.saved_variables
grad_input = grad_weight = grad_bias = None
# needs_input_grad 檢查是可選的。如果想使得 程式碼更簡單的話,可以忽略。
# 給不需要梯度的 引數返回梯度 不是一個錯誤。
# 返回值 的個數 需要和 forward 形參的個數(不包含 ctx)一致
if ctx.needs_input_grad[0]:
grad_input = grad_output.mm(weight)
if ctx.needs_input_grad[1]:
grad_weight = grad_output.t().mm(input)
if bias is not None and ctx.needs_input_grad[2]:
grad_bias = grad_output.sum(0).squeeze(0)
# 梯度的順序和 forward 形參的順序要對應。
return grad_input, grad_weight, grad_bias
關於 ctx
save_for_backward
只能存tensor, None
, 其餘都不能存。save_for_backward
只儲存forward
的實參,或者forward
的返回值。
上面就是繼承 Function
的全過程,然後該怎麼使用呢?
# input, weight, 是 Variable
def linear(input, weight, bias=None):
# 一定是要 通過呼叫 apply 來用的。 Function.apply 中估計做了不少事情。
return LinearFunction.apply(input, weight, bias)
也可以將 LinearFunction 封裝到 nn.Module
裡面,以便更簡單的使用。
檢查梯度計算是否正確
pytorch
提供了一個簡單的 介面用來檢查 定義的 梯度計算是否正確
from torch.autograd import gradcheck
# Check gradients computed via small finite differences against analytical gradients
# 檢查的是 inputs 中 requires_grad=True 的梯度,
# 一定要記得 double() 一下!!!!!!
input = (Variable(torch.randn(20, 20).double(), requires_grad=True),
Variable(torch.randn(30, 20).double(), requires_grad=True),)
test = gradcheck(LinearFunction.apply, input, eps=1e-6, atol=1e-4)
# 如果通過,最後會列印一個 True
print(test)
總結
forward
的形參是Tensor
,return
的也是Tensor
backward
的形參是Variable
,return
也需要是Variable
gradcheck
的時候,記得將Tensor
的型別轉成double
, 使用float
會導致檢查失敗。-
GlobalMaxPool例子
class GlobalMaxPool(Function):
@staticmethod
def forward(ctx, inputs):
bs, c, h, w = inputs.size()
flatten_hw = inputs.view(bs, c, -1)
max_val, indices = torch.max(flatten_hw, dim=-1, keepdim=True)
max_val = max_val.view(bs, c, 1, 1)
ctx.save_for_backward(inputs, indices)
# 只有返回 indices, 才讓 save_for_backward。。。 迫不得已。
return max_val, indices
# This function has only a single output, so it gets only one gradient
@staticmethod
def backward(ctx, grad_max_val, grad_indices):
inputs, indices = ctx.saved_variables
bs, c, h, w = inputs.size()
grad_inputs = inputs.data.new().resize_as_(inputs.data).zero_().view(bs, c, -1)
grad_inputs.scatter_(-1, indices.data,
torch.squeeze(grad_max_val.data).contiguous().view(bs, c, 1))
grad_inputs = grad_inputs.view_as(inputs.data)
return Variable(grad_inputs, volatile=grad_max_val.volatile)
def global_max_pool(input):
return GlobalMaxPool.apply(input)
if __name__ == '__main__':
in_ = Variable(torch.randn(2, 1, 3, 3).double(), requires_grad=True)
res, _ = global_max_pool(in_)
# print(res)
res.sum().backward()
res = gradcheck(GlobalMaxPool.apply, (in_,))
print(res)