1. 程式人生 > 其它 >(14)pytorch之hook函式

(14)pytorch之hook函式

技術標籤:pythonpytorch

一、Hook 函式(不改變主體,實現額外功能)

1 torch.Tensor.register_hook(hook(這個引數為我們自己定義的hook函式))–針對tensor的hook

自己寫一個反向傳播hook函式,來獲取中間變數的梯度, 僅有一個輸入引數(梯度)
在這裡插入圖片描述
eg

w=t.tensor([1.],requires_grad=True)
x=t.tensor([2.],requires_grad=True)
a=t.add(w,x)
b=t.add(w,1)
y=t.mul(a,b)
a_grad=list()

def grad_hook
(grad):#手動定義一個hook函式,引數為梯度 a_grad.append(grad) handle=a.register_hook(grad_hook)#將a張量掛上hook y.backward() print(w.grad,x.grad,a.grad,b.grad,y.grad) print(a_grad)

在這裡插入圖片描述

直接計算中,中間變數的grad被釋放了,但我門根據hook儲存的grad還可以找到

還可以用hook函式修改梯度

def grad_hook(grad):#將grad放大六倍
    grad*=2
    return grad*3
handle=w.register_hook(grad_hook)
y.backward

2.Module.register_forward_hook:註冊module的前向傳播hook函式–用於module

定義hook函式,

在這裡插入圖片描述

引數:module:當前網路層
          input:當前網路層輸入資料
          output:當前網路層輸出資料  

eg

class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.conv1=nn.Conv2d(1,2,3)
        self.pool1=nn.MaxPool(2,2)
    def forward(self,x):
        x=self.conv1(x)
        x=self.pool1(x)
        return x
def forward_hook(module,data_input,data_output):
    fmap_block.append(data_output)
    input_block.append(data_input)

#初始化網路及引數    
net=Net()
net.conv1.weight[0].detach().fill_(1)
net.conv1.weight[1].detach().fill_(2)
net.conv1.bias.data.detach().zero_(0)

#註冊hook
fmap_block=list()
input_block=list()
net.conv1.register_forward_hook(forward_hook)#注意,這裡是hook繫結在卷積層,而不是net中

#生成資料
fake_img=t.ones((1,1,4,4))#batch size*channel*H*W
output=net(fake_img)#處理至卷積層時,module的forward其實是在call函式中,先進行pre_hook,在forward,在hook,再backward_hook.其實是四步。這時,把這三個hook功能完成了


loss_fnc=nn.L1Loss()
target=torch.randn_like(output)
loss=loss_fnc(target,output)
loss.backward()

3.Module.register_forward_pre_hook:註冊module的前向傳播前的hook函式–用於module

在這裡插入圖片描述
自己定義的hook主要引數:moudle:當前網路層
input:當前網路層輸入資料

4.Module.register_backward_hook:註冊module反向傳播的hook函式–用於module

在這裡插入圖片描述

主要引數:module當前網路層
grad_input:當前網路層輸入資料梯度
grad_output:當前網路層輸出資料梯度