[PyTorch 學習筆記] 5.2 Hook 函式與 CAM 演算法
阿新 • • 發佈:2020-09-07
> 本章程式碼:
>
> - [https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson5/hook_fmap_vis.py](https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson5/hook_fmap_vis.py)
> - [https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson5/hook_methods.py](https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson5/hook_methods.py)
> - [https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson5/weight_fmap_visualization.py](https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson5/weight_fmap_visualization.py)
這篇文章主要介紹瞭如何使用 Hook 函式提取網路中的特徵圖進行視覺化,和 CAM(class activation map, 類啟用圖)
# Hook 函式概念
Hook 函式是在不改變主體的情況下,實現額外功能。由於 PyTorch 是基於動態圖實現的,因此在一次迭代運算結束後,一些中間變數如非葉子節點的梯度和特徵圖,會被釋放掉。在這種情況下想要提取和記錄這些中間變數,就需要使用 Hook 函式。
PyTorch 提供了 4 種 Hook 函式。
## torch.Tensor.register_hook(hook)
功能:註冊一個反向傳播 hook 函式,僅輸入一個引數,為張量的梯度。
`hook`函式:
```
hook(grad)
```
引數:
- grad:張量的梯度
程式碼如下:
```
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)
# 儲存梯度的 list
a_grad = list()
# 定義 hook 函式,把梯度新增到 list 中
def grad_hook(grad):
a_grad.append(grad)
# 一個張量註冊 hook 函式
handle = a.register_hook(grad_hook)
y.backward()
# 檢視梯度
print("gradient:", w.grad, x.grad, a.grad, b.grad, y.grad)
# 檢視在 hook 函式裡 list 記錄的梯度
print("a_grad[0]: ", a_grad[0])
handle.remove()
```
結果如下:
```
gradient: tensor([5.]) tensor([2.]) None None None
a_grad[0]: tensor([2.])
```
在反向傳播結束後,非葉子節點張量的梯度被清空了。而通過`hook`函式記錄的梯度仍然可以檢視。
`hook`函式裡面可以修改梯度的值,無需返回也可以作為新的梯度賦值給原來的梯度。程式碼如下:
```
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)
a_grad = list()
def grad_hook(grad):
grad *= 2
return grad*3
handle = w.register_hook(grad_hook)
y.backward()
# 檢視梯度
print("w.grad: ", w.grad)
handle.remove()
```
結果是:
```
w.grad: tensor([30.])
```
## torch.nn.Module.register_forward_hook(hook)
功能:註冊 module 的前向傳播`hook`函式,可用於獲取中間的 feature map。
`hook`函式:
```
hook(module, input, output)
```
引數:
- module:當前網路層
- input:當前網路層輸入資料
- output:當前網路層輸出資料
下面程式碼執行的功能是 $3 \times 3$ 的卷積和 $2 \times 2$ 的池化。我們使用`register_forward_hook()`記錄中間卷積層輸入和輸出的 feature map。
``` class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 2, 3) self.pool1 = nn.MaxPool2d(2, 2) def forward(self, x): x = self.conv1(x) x = self.pool1(x) return x def forward_hook(module, data_input, data_output): fmap_block.append(data_output) input_block.append(data_input) # 初始化網路 net = Net() net.conv1.weight[0].detach().fill_(1) net.conv1.weight[1].detach().fill_(2) net.conv1.bias.data.detach().zero_() # 註冊hook fmap_block = list() input_block = list() net.conv1.register_forward_hook(forward_hook) # inference fake_img = torch.ones((1, 1, 4, 4)) # batch size * channel * H * W output = net(fake_img) # 觀察 print("output shape: {}\noutput value: {}\n".format(output.shape, output)) print("feature maps shape: {}\noutput value: {}\n".format(fmap_block[0].shape, fmap_block[0])) print("input shape: {}\ninput value: {}".format(input_block[0][0].shape, input_block[0])) ``` 輸出如下: ``` output shape: torch.Size([1, 2, 1, 1]) output value: tensor([[[[ 9.]], [[18.]]]],
``` class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 2, 3) self.pool1 = nn.MaxPool2d(2, 2) def forward(self, x): x = self.conv1(x) x = self.pool1(x) return x def forward_hook(module, data_input, data_output): fmap_block.append(data_output) input_block.append(data_input) # 初始化網路 net = Net() net.conv1.weight[0].detach().fill_(1) net.conv1.weight[1].detach().fill_(2) net.conv1.bias.data.detach().zero_() # 註冊hook fmap_block = list() input_block = list() net.conv1.register_forward_hook(forward_hook) # inference fake_img = torch.ones((1, 1, 4, 4)) # batch size * channel * H * W output = net(fake_img) # 觀察 print("output shape: {}\noutput value: {}\n".format(output.shape, output)) print("feature maps shape: {}\noutput value: {}\n".format(fmap_block[0].shape, fmap_block[0])) print("input shape: {}\ninput value: {}".format(input_block[0][0].shape, input_block[0])) ``` 輸出如下: ``` output shape: torch.Size([1, 2, 1, 1]) output value: tensor([[[[ 9.]], [[18.]]]],