用張量廣播機制實現神經網路反向傳播
阿新 • • 發佈:2021-02-22
### 正向傳播
要想了解反向傳播,先要了解正向傳播:正向傳播的每一步是,用一個或很多輸入生成一個輸出。
### 反向傳播
反向傳播的作用是計算模型引數的偏導數。再具體一點,反向傳播的每一個step就是:已知正向傳播的輸入本身,和輸出的偏導數,求出每個輸入的偏導數的過程。
反向傳播既簡單,又複雜:
* 它的原理很簡單:鏈式法則求偏導。
* 它的公式又很複雜:因為它的公式看起來真的很複雜。
### 模型的引數
反向傳播就是計算模型的引數的偏導數,所以介紹一下模型的引數:
* 模型裡有很多引數,引數的本質是張量,可以把張量看成多維陣列,也可以把張量看成[一顆樹](https://www.cnblogs.com/duck-and-duck/p/14283292.html)。
* 張量有形狀,張量的偏導數是一個`同樣形狀`的張量。
### 線性函式的反向傳播
線性函式就是 ```y = wx + b```,我們輸入x,w,和 b 就能得到y。y是我們算出來的,這個算y的過程就是正向傳播。
我們規定字母后面加 ```.g``` 表示偏導數,如 ```y.g``` 就是y的偏導數,```w.g``` 就是w的偏導數。
那麼我們的目的,就是根據 ```x```, ```w```, ```b``` 和 ```y.g``` 的值,分別算出 ```w```,```x```,和```b```的偏導數,而這個過程,就是反向傳播。
為了便於說明,我們假設了每個變數的形狀: x(1000, 784), w(784, 50), b(50), y(1000, 50)。
#### 計算 ```x.g```
```y = wx + b``` 對 `x` 求偏導 得 `w`,即我們要用 `w` 和 `y.g` 計算出 `x.g`。
`w` 的形狀是 (784, 50),`y.g`的形狀跟y相同,是(1000, 50),如何用這兩個形狀湊出 `x.g` 的(1000, 784)?
emmm,很簡單,就是這樣,然後那樣,就行了。看玩笑的。。其實就是 ```y.g``` 中間加一維,變成 (1000, 1, 50) ,然後再跟 `w` 搞一下,得到一個 (1000, 784, 50) 的形狀,再把最後一維消去,就得到 (1000, 784) 的形狀了。
即:
```x.g = (y.g.unsqueeze(1) * w).sum(dim=-1)```
#### 計算 `w.g`
同理咯,```y = wx + b``` 對 `w` 求偏導 得 `x`,即我們要用 `x` 和 `y.g` 計算出 `w.g`。
`x 的形狀是 (1000, 784),`y.g`的形狀跟y相同,是(1000, 50),如何用這兩個形狀湊出 `w.g` 的(784, 50)?
先將 `x` 最後加一維,變成 (1000, 784, 1),再將 `y.g` 中間加一維,變成 (1000, 1, 50),這倆搞一下,變成 (1000, 784, 50),再把開頭的那一維消去,就變成 (784, 50)了。
即:
```w.g = (x.unsqueeze(-1) * y.g.unsqueeze(1)).sum(dim=0)```
#### 計算 `b.g`
```y = wx + b``` 對 `b` 求偏導 得常數 `1`,所以直接用形狀為(1000, 50)的`y.g`來湊出形狀為(50)的`b.g`就可以了。
那麼就非常簡單了,直接把(1000, 50)消去最開始的那一維就能得到(50),即:
```b.g = y.g.sum(0)```
#### 線性函式的反向傳播程式碼
已知線性函式的輸入是 `inp`,輸出是 `out`,計算過程用到的兩個引數是 `w`和`b`,則反向傳播的程式碼如下:
```python
def back_lin(inp, w, b, out):
inp.g = (out.g.unsqueeze(1) * w).sum(dim=-1)
w.g = (inp.unsqueeze(-1) * out.g.unsqueeze(1)).sum(dim=0)
b.g = out.g.sum(0)
```
### relu函式的反向傳播
relu函式表示起來很簡單,就是 `max(x, 0)`,但是在 pytorch 中這樣寫是行不通的,所以用這面這個函式表示:
```python
def relu(x):
return x.clamp_min(0)
```
其反向傳播表示為:
```python
def back_relu(inp, out):
return (inp > 0).float() * out.g
```
### mse函式的反向傳播
mse函式用程式碼表示為:
```python
def mse(pred, target):
return (pred.squeeze(dim=-1)-target).pow(2).mean()
```
其反向傳播則是:
```python
def back_mse(pred, target):
return 2. * (pred.squeeze(dim=-1) - target).unsqueeze(dim=-1) / pred.shape[0]
```
### 測試
假設我們的模型結果為:輸入一個x,進行一次線性變換,再經過一次relu,然後再經過一次線性變換得到結果。
先隨機生成 輸入、輸出和各個引數:
```python
# 偽造輸入和答案
import torch
torch.manual_seed(0)
input_ = torch.randn(1000, 784).requires_grad_(True) # 輸入
target = torch.randn(1000) # 答案
# 建立其它引數
w1 = torch.randn(784, 50).requires_grad_(True)
b1 = torch.randn(50).requires_grad_(True)
w2 = torch.randn(50, 1).requires_grad_(True)
b2 = torch.randn(1).requires_grad_(True)
```
正向傳播得到模型的輸出:
```python
l1 = input_ @ w1 + b1
l2 = relu(l1)
output = l2 @ w2 + b2
loss = mse(output, target)
```
反向傳播:
```python
back_mse(output, target)
back_lin(l2, w2, b2, output)
back_relu(l1, l2)
back_lin(input_, w1, b1, l1)
```
此時 `w1.g`,`b1.g`和 `w2.g`,`b2.g`均已求出。
然後用pytorch自帶的反向傳播求一下梯度:
```python
# 先儲存一下手動求的梯度
w1g = w1.g.clone()
b1g = b1.g.clone()
w2g = w2.g.clone()
b2g = b2.g.clone()
input_ = input_.clone().requires_grad_(True)
w1 = w1.clone().requires_grad_(True)
b1 = b1.clone().requires_grad_(True)
w2 = w2.clone().requires_grad_(True)
b2 = b2.clone().requires_grad_(True)
l1 = input_ @ w1 + b1
l2 = relu(l1)
output = l2 @ w2 + b2
loss = mse(output, target)
loss.backward()
```
此時對比一下我們手動求得的梯度和呼叫系統函式求得的梯度,發現二者是相等的:
```python
def is_same(a, b):
return (a - b).max() < 1e-4
is_same(w1g, w1.grad), is_same(b2g, b2.grad), is_same(w2g, w2.grad), is_same(b2g, b2.grad)
"""輸出
(tensor(True), tensor(True), tensor(True), tensor(True))
"""
```
### 總結
藉助簡單的求導和張量的廣播機制,就可以推導實現神經網路的反向