1. 程式人生 > 其它 >Scalable Rule-Based Representation Learning for Interpretable Classification

Scalable Rule-Based Representation Learning for Interpretable Classification

目錄

Wang Z., Zhang W., Liu N. and Wang J. Scalable rule-based representation learning for interpretable classification. In Advances in Neural Information Processing Systems (NIPS), 2021.

傳統的諸如決策樹之類的機器學習方法具有很強的結構性, 也因此具有很好的可解釋性. 和深度學習方法相比, 這類方法比較難以推廣到大規模的問題上, 很重要的一個原因便是, 其離散的引數和結構導致無法利用梯度進行優化. 本文是對利用梯度來優化這些模型的一個嘗試.

主要內容

本文考慮的是上圖(a)中的離散模型, 其接受連續變數\(C_i\)和離散變數\(B_i\):

  1. 通過Binarization Layer 將連續變數\(C_i\)離散化並與\(B_i\)拼接得到輸入\(\bm{u}^{(0)}\);
  2. 對於Logical Layer, 其以\(\bm{u}^{l-1}\)為輸入, 輸出\(\bm{u}^l\), 其包含且\(\bm{r}\)和或\(\bm{s}\)兩個部分:
\[r_i^{(l)} = \bigwedge_{W_{ij}^{(l, 0)} = 1} u_j^{(l-1)}, \\ s_i^{(l)} = \bigvee_{W_{ij}^{(l, 1)} = 1} u_j^{(l-1)}. \\ \]

其中\(W^{(l, 0)}\)

表示\(\bm{r}\)\(\bm{u}\)的鄰接矩陣, 而\(W^{(l, 1)}\)表示\(\bm{s}\)\(\bm{u}\)的鄰接矩陣. 可以發現, Logical Layer中的輸入輸出和權重都是二元的.
3. 最後通過一個線性層進行分類, 需要說明的是, 線性層的權重是連續的.

顯然由於logical layer是離散的, 直接通過梯度更新是辦不到的. 一個自然的想法是用一個連續的版本\(\hat{\mathcal{F}}(X; \theta)\)進行替換, 更新連續的引數\(\theta\)然後獲得下列的離散的版本:

\[\mathcal{F}(X; q(\theta)), \quad q(x) = \mathbb{I}_{x > 0.5}. \]

顯然直接套用這個方法是低效的, 因為訓練過程和離散沒有任何關係, 我們沒法保證離散後的模型依舊是有效的, 此外還有一個問題, 上述離散模型如何匹配到一個連續的版本.

下面是一個有趣的解決方案, 假設\(\hat{W}_{i,j} \in [0, 1]\), 則

\[Conj (\bm{u}, W_i) = \prod_{j=1}^n \bigg\{1 - W_{i,j}(1 - u_j) \bigg\}, \\ Disj (\bm{u}, W_i) = 1 - \prod_{j=1}^n \bigg\{1 - W_{i,j}u_j \bigg\}, \\ \]

便為且和或操作的連續版本.
試想:

\[\begin{array}{ll} & r_i = 1 \\ \Leftrightarrow & \bigwedge_j [u_j^{(l-1)} \vee (1 - W_{ij})] = 1\\ \Leftrightarrow & \prod_j \bigg\{1 - W_{i,j}(1 - u_j) \bigg\} = 1.\\ \end{array} \]

其它情況可以類似推導, 實在是有趣.

但是上述式子在實際中會有一些梯度消失的問題(因為連乘號, 且內部是[0, 1]之間的), 所示在實際使用中, 作者加了一個投影運算元

\[Conj_+ = \mathbb{P}(Conj (\bm{u}, W_i)), \]

其中(這設計都是為了避免梯度消失, 怎麼想到的? 怎麼會往這個方向去想的?)

\[\mathbb{P}(v) = \frac{-1}{-1 + \log (v)}. \]

解決了連續版本的問題, 現在剩下的難啃的地方是如何更新\(\theta\)以保證\(q(\theta)\)也是有意義的.
作者採用如下的梯度更新公式:

\[\theta^{t+1} = \theta^t - \eta \frac{\partial \mathcal{L}(\bar{Y})}{\partial \bar{Y}} \cdot \frac{\partial \hat{Y}}{\partial \theta^t}, \]

其中\(\hat{Y} = \hat{\mathcal{F}}(X; \theta)\), \(\bar{Y} = \mathcal{F}(X; \bar{\theta})\).
作者用了一個嫁接的例子來說明該思想, 即損失關於預測的導數用離散的, 內部的導數用連續的.

我驚訝的是, 這些改動居然work? 太不可思議了.