1. 程式人生 > >周志華等提出RNN可解釋性方法,看看RNN內部都幹了些什麼

周志華等提出RNN可解釋性方法,看看RNN內部都幹了些什麼

選自 ArXiv,作者:Bo-Jian Hou, Zhi-Hua Zhou,機器之心編譯,參與:思源、曉坤。

除了數值計算,你真的知道神經網路內部在做什麼嗎?我們一直理解深度模型都靠裡面的運算流,但對於是不是具有物理意義、語義意義都還是懵懵懂懂。尤其是在迴圈神經網路中,我們只知道每一個時間步它都在利用以前的記憶抽取當前語義資訊,但具體到怎麼以及什麼的時候,我們就無能為力了。在本文中,南京大學的周志華等研究者嘗試利用有限狀態機探索 RNN 的內在機制,這種具有物理意義的模型可以將 RNN 的內部流程展現出來,並幫助我們窺探 RNN 到底都幹了些什麼。

結構化學習(Structure learning)的主要任務是處理結構化的輸出,它不像分類問題那樣為每個獨立的樣本預測一個值。這裡所說的結構可以是圖、序列、樹形結構和向量等。一般用於結構化輸出的機器學習演算法有各種概率圖模型、感知機和 SVM 等。在過去的數十年裡,結構化學習已經廣泛應用於目標追蹤、目標定位和語義解析等任務,而多標籤學習和聚類等很多問題同樣與結構化學習有很強的關聯。

一般來說,結構化學習會使用結構化標註作為監督資訊,並藉助相應的演算法來預測這些結構化資訊而實現優良的效能。然而,隨著機器學習演算法變得越來越複雜,它們的可解釋性則變得越來越重要,這裡的可解釋性指的是如何理解學習過程的內在機制或內部流程。在這篇論文中,周志華等研究者重點關注深度學習模型,並探索如何學習這些模型的結構以提升模型可解釋性。

探索深度學習模型的可解釋性通常都比較困難,然而對於 RNN 等特定型別的深度學習模型,我們還是有方法解決的。迴圈神經網路(RNN)作為深度神經網路中的主要組成部分,它們在各種序列資料任務中有非常廣泛的應用,特別是那些帶有門控機制的變體,例如帶有一個門控的 MGU、帶有兩個門控的 GRU 和三個門控的 LSTM。

除了我們熟悉的 RNN 以外,還有另一種工具也能捕捉序列資料,即有限狀態機(Finite State Automaton/FSA)。FSA 由有限狀態和狀態之間的轉換組成,它將從一個狀態轉換為另一個狀態以響應外部序列輸入。FSA 的轉換過程有點類似於 RNN,因為它們都是一個一個接收序列中的輸入元素,並在相應的狀態間傳遞。與 RNN 不同的是,FSA 的內部機制更容易被解釋,因為我們更容易模擬它的過程。此外,FSA 在狀態間的轉換具有物理意義,而 RNN 只有數值計算的意義。

FSA 的這些特性令周志華團隊探索從 RNN 中學習一個 FSA 模型,並利用 FSA 的天然可解釋效能力來理解 RNN 的內部機制,因此周志華等研究者採用 FSA 作為他們尋求的可解釋結構。此外,這一項研究與之前關於結構化學習的探索不同。之前的方法主要關注結構化的預測或分類結果,這一篇文章主要關注中間隱藏層的輸出結構,這樣才能更好地理解 RNN 的內在機制。

為了從 RNN 中學習 FSA,並使用 FSA 解釋 RNN 的內在機制,我們需要知道如何學習 FSA 以及具體解釋 RNN 中的什麼。對於如何學習 FSA,研究者發現非門控的經典 RNN 隱藏狀態傾向於構造一些叢集。但是仍然存在一些重要的未解決問題,其中之一是我們不知道構造叢集的傾向在門控 RNN 中是否也存在。我們同樣需要考慮效率問題,因為門控 RNN 變體通常用於大型資料集中。對於具體解釋 RNN 中的什麼,研究者分析了門控機制在 LSTM、GRU 和 MGU 等模型中的作用,特別是不同門控 RNN 中門的數量及其影響。鑑於 FSA 中狀態之間的轉換是有物理意義的,因此我們可以從與 RNN 對應的 FSA 轉換推斷出語義意義。

在這篇論文中,周志華等研究者嘗試從 RNN 學習 FSA,他們首先驗證了除不帶門控的經典 RNN 外,其它門控 RNN 變體的隱藏狀態同樣也具有天然的叢集屬性。然後他們提出了兩種方法,其一是高效的聚類方法 k-means++。另外一種方法根據若相同序列中隱藏狀態相近,在幾何空間內也相近的現象而提出,這一方法被命名為 k-means-x。隨後研究者通過設計五個必要的元素來學習 FSA,即字母表、一組狀態、初始狀態、一組接受狀態和狀態轉換,他們最後將這些方法應用到了模擬資料和真實資料中。

對於人工模擬資料,研究者首先表示我們可以理解在執行過程學習到的 FSA。然後他們展示了準確率和叢集數量之間的關係,並表示門控機制對於門控 RNN 是必要的,並且門越少越好。這在一定程度上解釋了為什麼只有一個門控的 MGU 在某種程度上優於其它門控 RNN。

對於情感分析這一真實資料,研究者發現在數值計算的背後,RNN 的隱藏狀態確實具有區分語義差異性的能力。因為在對應的 FSA 中,導致正類/負類輸出的詞確實在做一些正面或負面的人類情感理解。

論文:Learning with Interpretable Structure from RNN


論文地址:arxiv.org/pdf/1810.10…

摘要:在結構化學習中,輸出通常是一個結構,可以作為監督資訊用於獲取良好的效能。考慮到深度學習可解釋性在近年來受到了越來越多的關注,如果我們能重深度學習模型中學到可解釋的結構,將是很有幫助的。在本文中,我們聚焦於迴圈神經網路(RNN),它的內部機制目前仍然是沒有得到清晰的理解。我們發現處理序列資料的有限狀態機(FSA)有更加可解釋的內部機制,並且可以從 RNN 學習出來作為可解釋結構。我們提出了兩種不同的聚類方法來從 RNN 學習 FSA。我們首先給出 FSA 的圖形,以展示它的可解釋性。從 FSA 的角度,我們分析了 RNN 的效能如何受到門控數量的影響,以及數值隱藏狀態轉換背後的語義含義。我們的結果表明有簡單門控結構的 RNN 例如最小門控單元(MGU)的表現更好,並且 FSA 中的轉換可以得到和對應單詞相關的特定分類結果,該過程對於人類而言是可理解的。

本文的方法

在這一部分,我們介紹提出方法的直覺來源和方法框架。我們將 RNN 的隱藏狀態表示為一個向量或一個點。因此當多個序列被輸入到 RNN 時,會積累大量的隱藏狀態點,並且它們傾向於構成叢集。為了驗證該結論,我們展示了在 MGU、SRU、GRU 和 LSTM 上的隱藏狀態點的分佈,如圖 1(a)到(d)所示。

圖 1:隱藏狀態點由 t-SNE 方法降維成 2 個維度,我們可以看到隱藏狀態點傾向於構成叢集。

圖 2 展示了整個框架。我們首先在訓練資料集上訓練 RNN,然後再對應驗證資料 V 的所有隱藏狀態 H 上執行聚類,最後學習一個關於 V 的 FSA。再得到 FSA 後,我們可以使用它來測試未標記資料或直接畫出圖示。再訓練 RNN 的第一步,我們利用了和 [ZWZZ16] 相同的策略,在這裡忽略了細節。之後,我們會詳細介紹隱藏狀態聚類和 FSA 學習步驟(參見原文)。

圖 2:本文提出演算法的框架展示。黃色圓圈表示隱藏狀態,由 h_t 表示,這裡 t 是時間步。「A」是迴圈單元,接收輸入 x_t 和 h_t-1 並輸出 h_t。結構化 FSA 的雙圓圈是接受狀態。總體來說,該框架由三個步驟構成,即訓練 RN 你模型、聚類隱藏狀態和輸出結構化 FSA。

完整的從 RNN 學習 FSA 的過程如演算法 1 所示。我們將該方法稱為 LISOR,並展示了兩種不同的聚類演算法。基於 k-means++的被稱為「LISOR-k」,基於 k-means-x 的被稱為「LISOR-x」。通過利用構成隱藏狀態點的聚類傾向,LISOR-k 和 LISOR-x 都可以從 RNN 學習到良好泛化的 FSA。

實驗結果

在這一部分,我們在人工和真實任務上進行了實驗,並可視化了從對應 RNN 模型學習到的 FSA。除此之外,在兩個任務中,我們討論了我們如何從 FSA 解釋 RNN 模型,以及展示使用學習到的 FSA 來做分類的準確率。

第一個人工任務是在一組長度為 4 的序列中(只包含 0 和 1)識別序列「0110」(任務「0110」). 這是一個簡單的只包含 16 個不同案例的任務。我們在訓練集中包含了 1000 個例項,通過重複例項來提高準確率。我們使用包含所有可能值且沒有重複的長度為 4 的 0-1 序列來學習 FSA,並隨機生成 100 個例項來做測試。

第二個人工任務是確定一個序列是否包含三個連續的 0(任務「000」)。這裡對於序列的長度沒有限制,因此該任務有無限的例項空間,並且比任務「0110」更困難。我們隨機生成 3000 個 0-1 訓練例項,其長度是隨機確定的。我們還生成了 500 個驗證例項和 500 個測試例項。

表 2:分別基於 LISOR-k 和 LISOR-x 方法,當從 4 個 RNN 中學習到的 FSA 在任務「0110」的準確率首次達到 1.0 時,叢集的數量(n_c)。注意這些值是越小越高效,並且可解釋性越好。不同試驗中訓練得到的 RNN 模型使用了不同的引數初始化。

如表 2 所示,我們可以看到在從 MGU 學習到的 FSA 的平均叢集數量總是能以最小的叢集數量達到準確率 1.0。叢集數量為 65 意味著 FSA 的準確率在直到 n_c 為 64 時都無法達到 1.0。每次試驗的最小叢集數量和平均最小叢集數量加粗表示。

表 3:分別基於 LISOR-k 和 LISOR-x 方法,當從 4 個 RNNzho 中學習到的 FSA 在任務「000」的準確率首次達到 0.7 時,叢集的數量(n_c)。注意這些值是越小越高效,並且可解釋性越好。

圖 3:在任務「0110」訓練 4 個 RNN 時學習得到的 FSA 結構圖示。叢集數量 k 由 FSA 首次達到準確率 1.0 時的聚類數量決定。0110 的路由用紅色表示。注意在圖(d)中由 4 個獨立於主要部分的節點。這是因為我們捨棄了當輸入一個符號來學習一個確定性 FSA 時更小頻率的轉換。

圖 4:在任務「000」訓練 4 個 RNN 時學習得到的 FSA 結構圖示。叢集數量 k 由 FSA 首次達到準確率 0.7 時的叢集數量決定。


圖 7:在情感分析任務上訓練的 MGU 學習到的 FSA。這裡的 FSA 經過壓縮,並且相同方向上的相同兩個狀態之間的詞被分成同一個詞類。例如,「word class 0-1」中的詞全部表示從狀態 0 轉換為狀態 1。

表 5:從狀態 0 轉換的詞稱為可接受狀態(即包含積極電影評論的狀態 1),其中大多數詞都是積極的。這裡括號中的數字表示詞來源的 FSA 編號。

表 4:當叢集數量為 2 時情感分類任務的準確率。