TensorFlow中的一個重要ops---MatMul的實現(二)
上面一篇文章我們主要介紹了MatMul這個ops的正向計算的實現,我們知道模型的最後一步是計算優化模型的引數,而一般採用的方法是梯度下降法,所以每個ops節點不僅要實現一個正向的計算節點,還要實現一個反向的梯度計算節點。
關於反向計算的結點官網有如下一段介紹:
Implement the gradient in Python
Given a graph of ops, TensorFlow uses automatic differentiation (backpropagation) to add new ops representing gradients with respect to the existing ops (see
給定一個 Ops 組成的圖, TensorFlow 使用自動微分 (反向傳播) 來新增新的 Ops 以表示梯度運算, 同時不影響已有的 Ops (參見梯度運算). 為了使自動微分能夠與新的
Ops 協同工作, 必須註冊一個梯度函式, 從 Ops 的輸入計算梯度, 並返回代表梯度值的輸出.數學上, 如果一個 Ops 計算 y = f(x) , 註冊的梯度 Ops 通過以下鏈式法則, 將 ∂L/∂y的梯度運算轉化為 ∂L/∂x
總結起來就是一個圖無非是如下結構的遞迴堆疊:
L=g(f(x))
因為TensorFlow的圖有個特點上個節點的輸出是下個節點的輸入,這個不正滿足了鏈式法則的要求。對於上式,假設y=f(x),我們要求的是 z 相對於 x 的梯度(求出梯度就可以利用梯度下降法計算下一步的變數值了),也就是∂L/∂x的值,只需要先分別求出第一個ops節點的梯度∂y/∂x,和第二個ops節點的梯度 ∂L/∂y,依據鏈式法則自然有:
加上一點想象力,假設L=g(y)也是一個符合函式,那不就是可以擴充套件成有很多節點的複雜的圖了。
接下來我們看下MatMul梯度節點的實現:程式碼在:python/ops/math_grad.py檔案裡面是一個Python的實現,
@ops.RegisterGradient("MatMul")
def _MatMulGrad(op, grad):
"""Gradient for MatMul."""
t_a = op.get_attr("transpose_a")
t_b = op.get_attr("transpose_b")
a = math_ops.conj(op.inputs[0])
b = math_ops.conj(op.inputs[1])
if not t_a and not t_b:
grad_a = math_ops.matmul(grad, b, transpose_b=True)
grad_b = math_ops.matmul(a, grad, transpose_a=True)
elif not t_a and t_b:
grad_a = math_ops.matmul(grad, b)
grad_b = math_ops.matmul(grad, a, transpose_a=True)
elif t_a and not t_b:
grad_a = math_ops.matmul(b, grad, transpose_b=True)
grad_b = math_ops.matmul(a, grad)
elif t_a and t_b:
grad_a = math_ops.matmul(b, grad, transpose_a=True, transpose_b=True)
grad_b = math_ops.matmul(grad, a, transpose_a=True, transpose_b=True)
return grad_a, grad_b
我們首先看
@ops.RegisterGradient("MatMul")
這個是表示這個申明的是一個梯度
然後我們看到,返回值有兩個:grad_a,grad_b,為什麼呢,因為MatMul(A ,B)這個ops節點函式的引數有兩個,A和B,他也不知道那個是變數那個是定值,所以就對兩個變數都求了梯度,然後可以選擇利用哪個。
現在那問題就轉化為對於函式 y=A*B,分別求y相對於A矩陣和B矩陣的導數,
矩陣求導有如下公式:
Y = A * X --> DY/DX = A'
Y = X * A --> DY/DX = A
Y = A' * X * B --> DY/DX = A * B'
Y = A' * X' * B --> DY/DX = B * A'
更多關於矩陣求導可以參考下文:
http://blog.csdn.net/wbgxx333/article/details/22992977
那現在再來看上面的求導結果就非常清楚了,而且應該說還是比較簡單的