1. 程式人生 > >pytorch學習筆記(十七):python 端擴充套件 pytorch

pytorch學習筆記(十七):python 端擴充套件 pytorch

pytorch 雖然提供了很多的 op 使得我們很容易的使用。但是當已有的 op 無法滿足我們的要求的時候,那就需要自己動手來擴充套件。 pytorch 提供了兩種方式來擴充套件 pytorch 的基礎功能。

  • 通過繼承 autograd.Function
  • 通過 C 來擴充套件

本篇部落格主要介紹 繼承 autograd.Function 來擴充套件 pytorch

繼承 autograd.Function 的 子類 只需要 實現兩個 靜態方法:

  • forward : 計算 op 的前向過程.
    • 在執行 forward 之前,Variable 引數已經被轉換成了 Tensor
    • forward 的形參可以有預設引數,預設引數可以是任意 python 物件。
    • 可以返回任意多個 Tensor
    • 裡面可以使用任何 python 操作,但是 return 的值必須是 Tensor !!!
  • backward : 計算 梯度,
    • forward 返回幾個 值, 這裡就需要幾個 形參,還得外加一個 ctx
    • forward 有幾個 形參(不包含 ctx) ,backward 就得返回幾個值。
    • bacward 實參也是 Variable
    • backward 返回的得是 Variable

一個 Demo(來自官網)

class LinearFunction(Function)
:
# forward 和 backward 都得是 靜態方法!!!!! @staticmethod # bias 是個可選引數,有個 預設值 None def forward(ctx, input, weight, bias=None): # input,weight 都已經變成了 Tensor # 用 ctx 把該存的存起來,留著 backward 的時候用 ctx.save_for_backward(input, weight, bias) output = input.mm(weight.t()) if
bias is not None: output += bias.unsqueeze(0).expand_as(output) return output # 由於 forward 只有一個 返回值,所以 backward 只需要一個引數 接收 梯度。 @staticmethod def backward(ctx, grad_output): # grad_output 是 Variable 型別。 # 在開頭的地方將儲存的 tensor 給 unpack 了 # 然後 給 所有應該返回的 梯度 以 None 初始化。 # saved_variables 返回的是 Variable!!! 不是 Tensor 了。 input, weight, bias = ctx.saved_variables grad_input = grad_weight = grad_bias = None # needs_input_grad 檢查是可選的。如果想使得 程式碼更簡單的話,可以忽略。 # 給不需要梯度的 引數返回梯度 不是一個錯誤。 # 返回值 的個數 需要和 forward 形參的個數(不包含 ctx)一致 if ctx.needs_input_grad[0]: grad_input = grad_output.mm(weight) if ctx.needs_input_grad[1]: grad_weight = grad_output.t().mm(input) if bias is not None and ctx.needs_input_grad[2]: grad_bias = grad_output.sum(0).squeeze(0) # 梯度的順序和 forward 形參的順序要對應。 return grad_input, grad_weight, grad_bias

關於 ctx

  • save_for_backward 只能存 tensor, None, 其餘都不能存。
  • save_for_backward 只儲存 forward 的實參,或者 forward 的返回值。

上面就是繼承 Function 的全過程,然後該怎麼使用呢?

# input, weight, 是 Variable
def linear(input, weight, bias=None):
    # 一定是要 通過呼叫 apply 來用的。 Function.apply 中估計做了不少事情。
    return LinearFunction.apply(input, weight, bias)

也可以將 LinearFunction 封裝到 nn.Module 裡面,以便更簡單的使用。

檢查梯度計算是否正確

pytorch 提供了一個簡單的 介面用來檢查 定義的 梯度計算是否正確

from torch.autograd import gradcheck
# Check gradients computed via small finite differences against analytical gradients

# 檢查的是 inputs 中 requires_grad=True 的梯度,
# 一定要記得 double() 一下!!!!!!
input = (Variable(torch.randn(20, 20).double(), requires_grad=True),
             Variable(torch.randn(30, 20).double(), requires_grad=True),)
test = gradcheck(LinearFunction.apply, input, eps=1e-6, atol=1e-4)
# 如果通過,最後會列印一個 True
print(test)

總結

  • forward 的形參是 Tensorreturn 的也是 Tensor
  • backward 的形參是 Variablereturn 也需要是 Variable
  • gradcheck 的時候,記得將 Tensor 的型別轉成 double, 使用 float 會導致檢查失敗。

GlobalMaxPool例子

class GlobalMaxPool(Function):
    @staticmethod
    def forward(ctx, inputs):
        bs, c, h, w = inputs.size()
        flatten_hw = inputs.view(bs, c, -1)
        max_val, indices = torch.max(flatten_hw, dim=-1, keepdim=True)
        max_val = max_val.view(bs, c, 1, 1)
        ctx.save_for_backward(inputs, indices)
        # 只有返回 indices, 才讓 save_for_backward。。。 迫不得已。
        return max_val, indices

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_max_val, grad_indices):
        inputs, indices = ctx.saved_variables

        bs, c, h, w = inputs.size()
        grad_inputs = inputs.data.new().resize_as_(inputs.data).zero_().view(bs, c, -1)
        grad_inputs.scatter_(-1, indices.data,
                             torch.squeeze(grad_max_val.data).contiguous().view(bs, c, 1))
        grad_inputs = grad_inputs.view_as(inputs.data)

        return Variable(grad_inputs, volatile=grad_max_val.volatile)


def global_max_pool(input):
    return GlobalMaxPool.apply(input)


if __name__ == '__main__':
    in_ = Variable(torch.randn(2, 1, 3, 3).double(), requires_grad=True)
    res, _ = global_max_pool(in_)
    # print(res)

    res.sum().backward()
    res = gradcheck(GlobalMaxPool.apply, (in_,))
    print(res)