1. 程式人生 > 實用技巧 >N0104-計算圖與動態圖機制

N0104-計算圖與動態圖機制

本章程式碼

計算圖

深度學習就是對張量進行一系列的操作,隨著操作種類和數量的增多,會出現各種值得思考的問題。比如多個操作之間是否可以並行,如何協同底層的不同裝置,如何避免冗餘的操作,以實現最高效的計算效率,同時避免一些 bug。因此產生了計算圖 (Computational Graph)。

計算圖是用來描述運算的有向無環圖,有兩個主要元素:節點 (Node) 和邊 (Edge)。節點表示資料,如向量、矩陣、張量。邊表示運算,如加減乘除卷積等。

用計算圖表示:\(y=(x+w)*(w+1)\),如下所示:

可以看作, \(y=a \times b\) ,其中 \(a=x+w\)\(b=w+1\)

計算圖與梯度求導

這裡求 \(y\)\(w\) 的導數。根複合函式的求導法則,可以得到如下過程。

\[\begin{aligned} \frac{\partial y}{\partial w} &=\frac{\partial y}{\partial a} \frac{\partial a}{\partial w}+\frac{\partial y}{\partial b} \frac{\partial b}{\partial w} \\ &=b * 1+a * 1 \\ &=b+a \\ &=(w+1)+(x+w) \\ &=2 * w+x+1 \\ &=2 * 1+2+1=5\end{aligned} \]

體現到計算圖中,就是根節點 yy 到葉子節點 ww 有兩條路徑 y -> a -> wy ->b -> w。根節點依次對每條路徑的孩子節點求導,一直到葉子節點w,最後把每條路徑的導數相加即可。

程式碼如下:

import torch
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
# y=(x+w)*(w+1)
a = torch.add(w, x)     # retain_grad()
b = torch.add(w, 1)
y = torch.mul(a, b)
# y 求導
y.backward()
# 列印 w 的梯度,就是 y 對 w 的導數
print(w.grad)

結果為tensor([5.])

我們回顧前面說過的 Tensor 中有一個屬性is_leaf標記是否為葉子節點。

在上面的例子中,\(x\)\(w\) 是葉子節點,其他所有節點都依賴於葉子節點。葉子節點的概念主要是為了節省記憶體,在計算圖中的一輪反向傳播結束之後,非葉子節點的梯度是會被釋放的。

程式碼示例:

# 檢視葉子結點
print("is_leaf:\n", w.is_leaf, x.is_leaf, a.is_leaf, b.is_leaf, y.is_leaf)

# 檢視梯度
print("gradient:\n", w.grad, x.grad, a.grad, b.grad, y.grad)

結果為:

is_leaf:
 True True False False False
gradient:
 tensor([5.]) tensor([2.]) None None None

非葉子節點的梯度為空,如果在反向傳播結束之後仍然需要保留非葉子節點的梯度,可以對節點使用retain_grad()方法。

而 Tensor 中的 grad_fn 屬性記錄的是建立該張量時所用的方法 (函式)。而在反向傳播求導梯度時需要用到該屬性。

示例程式碼:

# 檢視梯度
print("w.grad_fn = ", w.grad_fn)
print("x.grad_fn = ", x.grad_fn)
print("a.grad_fn = ", a.grad_fn)
print("b.grad_fn = ", b.grad_fn)
print("y.grad_fn = ", y.grad_fn)

結果為

w.grad_fn =  None
x.grad_fn =  None
a.grad_fn =  <AddBackward0 object at 0x000001D8DDD20588>
b.grad_fn =  <AddBackward0 object at 0x000001D8DDD20588>
y.grad_fn =  <MulBackward0 object at 0x000001D8DDD20588>

PyTorch 的動態圖機制

PyTorch 採用的是動態圖機制 (Dynamic Computational Graph),而 Tensorflow 採用的是靜態圖機制 (Static Computational Graph)。

動態圖是運算和搭建同時進行,也就是可以先計算前面的節點的值,再根據這些值搭建後面的計算圖。優點是靈活,易調節,易除錯。PyTorch 裡的很多寫法跟其他 Python 庫的程式碼的使用方法是完全一致的,沒有任何額外的學習成本。

靜態圖是先搭建圖,然後再輸入資料進行運算。優點是高效,因為靜態計算是通過先定義後執行的方式,之後再次執行的時候就不再需要重新構建計算圖,所以速度會比動態圖更快。但是不靈活。TensorFlow 每次執行的時候圖都是一樣的,是不能夠改變的,所以不能直接使用 Python 的 while 迴圈語句,需要使用輔助函式 tf.while_loop 寫成 TensorFlow 內部的形式。

參考資料