(14)pytorch之hook函式
阿新 • • 發佈:2020-12-24
一、Hook 函式(不改變主體,實現額外功能)
1 torch.Tensor.register_hook(hook(這個引數為我們自己定義的hook函式))–針對tensor的hook
自己寫一個反向傳播hook函式,來獲取中間變數的梯度, 僅有一個輸入引數(梯度)
eg
w=t.tensor([1.],requires_grad=True)
x=t.tensor([2.],requires_grad=True)
a=t.add(w,x)
b=t.add(w,1)
y=t.mul(a,b)
a_grad=list()
def grad_hook (grad):#手動定義一個hook函式,引數為梯度
a_grad.append(grad)
handle=a.register_hook(grad_hook)#將a張量掛上hook
y.backward()
print(w.grad,x.grad,a.grad,b.grad,y.grad)
print(a_grad)
直接計算中,中間變數的grad被釋放了,但我門根據hook儲存的grad還可以找到
還可以用hook函式修改梯度
def grad_hook(grad):#將grad放大六倍 grad*=2 return grad*3 handle=w.register_hook(grad_hook) y.backward
2.Module.register_forward_hook:註冊module的前向傳播hook函式–用於module
定義hook函式,
引數:module:當前網路層
input:當前網路層輸入資料
output:當前網路層輸出資料
eg
class Net(nn.Module): def __init__(self): super(Net,self).__init__() self.conv1=nn.Conv2d(1,2,3) self.pool1=nn.MaxPool(2,2) def forward(self,x): x=self.conv1(x) x=self.pool1(x) return x def forward_hook(module,data_input,data_output): fmap_block.append(data_output) input_block.append(data_input) #初始化網路及引數 net=Net() net.conv1.weight[0].detach().fill_(1) net.conv1.weight[1].detach().fill_(2) net.conv1.bias.data.detach().zero_(0) #註冊hook fmap_block=list() input_block=list() net.conv1.register_forward_hook(forward_hook)#注意,這裡是hook繫結在卷積層,而不是net中 #生成資料 fake_img=t.ones((1,1,4,4))#batch size*channel*H*W output=net(fake_img)#處理至卷積層時,module的forward其實是在call函式中,先進行pre_hook,在forward,在hook,再backward_hook.其實是四步。這時,把這三個hook功能完成了 loss_fnc=nn.L1Loss() target=torch.randn_like(output) loss=loss_fnc(target,output) loss.backward()
3.Module.register_forward_pre_hook:註冊module的前向傳播前的hook函式–用於module
自己定義的hook主要引數:moudle:當前網路層
input:當前網路層輸入資料
4.Module.register_backward_hook:註冊module反向傳播的hook函式–用於module
主要引數:module當前網路層
grad_input:當前網路層輸入資料梯度
grad_output:當前網路層輸出資料梯度