1. 程式人生 > 其它 >座標下降法

座標下降法

座標下降法

本文講解如何使用座標下降法求解最小二乘問題。

原理

假設 \(A\in \mathbb{R}^{N\times K}\)\(b\in \mathbb{R}^N\),求 \(x = [x_1, \cdots, x_K]^T\in \mathbb{R}^K\),使得

\[\lVert b - Ax\rVert_2 \]

極小化。

座標下降法的想法是一次只更新一個分量,假設更新第 \(k\) 個分量 \(x_k\),即要求一個 \(s\)

\[x_k \leftarrow x_k + s \]

用以更新 \(x_k\)。怎麼選取這個 \(s\) 呢?當然是選可以使損失函式取得最小值的那個 \(s\)

記更新後的損失函式為

\[g^k(s) = \lVert b - A[x_1, \cdots, x_k+s, \cdots, x_K]^T\rVert_2^2 = \sum_{i=1}^N \bigg( e_i - sA_{ik} \bigg)^2 \]

其中 \(e=b-Ax\in \mathbb{R}^N\)。由於 \(g^k\) 是二次函式,Taylor 展開後與其自身相等,即有

\[g^k(s) = g^k(0) + (g^k)'(0) s+ \frac12 (g^k)''(0)s^2, \]

其極小點為

\[s^{*} = -\frac{(g^k)'(0)}{(g^k)''(0)}. \]

因此可以更新 \(x_k\)

,

\[x_k \leftarrow x_k + s^{*}. \]

另外,簡單計算可知一階導數和二階導數在 \(0\) 處的取值為

\[\left\{ \begin{aligned} (g^k)'(0) &= -2\sum_{i=1}^{N}e_iA_{ik} \\ (g^k)''(0) &= 2 \sum_{i=1}^{N}A_{ik}^2 \end{aligned} \right. \]

虛擬碼

上面講了座標下降法的原理,寫成虛擬碼如下:

輸入:係數矩陣 \(A\in \mathbb{R}^{N\times K}\),標籤向量 \(b\in \mathbb{R}^N\)


過程
初始化 \(x\in \mathbb{R}^K\)
計算初始誤差:\(e = b - Ax\)
提前計算每個分量處的二階導數:\(h (k) = 2\sum_{i=1}^{N}A_{ik}^2\)
while 達到收斂條件:
選取 \(k \in \{1, ..., K\}\) :
計算一階導數:\(g = -2\sum_{i=1}^Ne_i A_{ik}\)
計算極小點:\(s^{*} = \frac{g}{h (k)}\)
座標更新:\(x_k = x_k + s^{*}\)
更新誤差:\(e_i = e_i - s^{*} A_{ik}, \quad i=1, \cdots, N\)
輸出\(x\in \mathbb{R}^K\)

分析

  • 複雜度:座標下降演算法單輪迭代的複雜度為 \(O(KN)\),設一共迭代了 \(T\) 次,則總複雜度為 \(O(TKN)\)
  • 上述實現中,每輪迭代時座標的選取沒有確定,最簡單的方式是按順序迴圈更新,也可以隨機選取。另外,也可以採用貪心的策略選取(有時間寫一下)。
  • 對於非線性的問題,也可類似處理。