1. 程式人生 > 實用技巧 >FastBERT——自蒸餾模型與適應性調整推斷時間技術

FastBERT——自蒸餾模型與適應性調整推斷時間技術

前言

這幾天被分配到的任務是調研現有的幾種基於 BERT 的蒸餾模型,所以花了些時間瞭解了下 DistilBERT,TINYBERT 以及 FastBERT.

自從 BERT 釋出以來,各種基於 BERT 的改良版本(如 RoBERTa)層出不窮,模型效果也有著不斷的提升,但礙於這些模型過於巨大,需要大量的計算資源,在工業中的應用受限。所以,大家都在想方設法地對 BERT 進行“瘦身”,期望經瘦身後的模型能帶來更快的推斷速度,並大部分地保留原始 BERT 的推斷效能。為了達到此目的,研究者們分成了多個方向:量化(quantinization)、權重剪枝(weights pruning)和知識蒸餾(knowledge distillation,KD)。

知識蒸餾(以下簡稱為 KD)一般需要兩個模型,一大一小,大的充當老師的角色,而小的則作為學生學習老師的行為,目標就是希望學生的小腦瓜子能夠容納住老師的畢生所學,獨當一面(雖然能力可能會比老師差點,但夠用就行)。對於現今 NLP 領域中的 KD,老師和學生模型大多都是基於 Transformers 的模型,一般將 BERT 作為老師,而另一個小的基於 Transformers 的模型作為學生。

通過了解,我發現 FastBERT 更有趣些,它採用 自蒸餾 (self-distilling)技術,即整個蒸餾過程僅需要單個模型,該模型即是老師又是學生,與以往的需要兩個模型的蒸餾方案不同(如DistilBERT、TinyBERT),而且它能根據樣本的難易程度 適應性地調整推斷時間

(adaptive inference time)。所以本文接下來重點介紹 FastBERT.

首先,通過論文的名字 《FastBERT: a Self-distilling BERT with Adaptive Inference Time》就能看出 FastBERT 的兩大特性:self-distilling 和 adaptive inference time.

需要注意的是:該論文只使用文字分類任務做實驗,沒有嘗試 token-level 的分類任務(如命名實體識別),所以將 FastBERT 應用到 token-level 的分類不一定適用或適用效果不好(純猜測)。

模型結構

FastBERT's architecture

Figure 1. 模型結構與推斷過程

FastBERT 由兩個成分組成:主幹(backbone) 和 分支(branches)。如圖 1,主幹結構與傳統的 BERT 一致,包含 12 層的 Transformer block,此外再加上一個分類器,取名為 老師分類器(teacher-classifier);分支由多個 學生分類器(student-classifiers)組成,其結構與老師分類器一致,這些分類器被加在每一層 Transformer block 之上,即第 i i i 層 Transformer block 的輸出作為第 i i i 個 student-classifier 輸入。

在 FastBERT 中,主幹和分支(student-classifiers)是分開訓練的,其 訓練過程 如下:

  1. 主幹(除了 teacher-classifier)部分的預訓練
  2. 整體主幹(包含 teacher-classifier)的微調
  3. 分支(student-classifiers)的自蒸餾

前兩步的訓練過程與傳統 BERT 一樣,所以你也可以採用 BERT 家族中的其他模型(如 BERT-WWM、RoBERTa 和 ERNIE)作為主幹。對於第一步,為了方便,我們可以直接在網上下載已公開的預訓練好的 BERT 家族的模型。然後根據任務在自己的標註語料上執行第二步,在這一階段,所有的 student-classifier 還未被用到。

自蒸餾

在自蒸餾階段(訓練的第三步),模型主幹(包括 teacher-classifier)需要被 freeze,即不再進行引數更新。使用 teacher-classifier 的輸出作為高質量的 軟標籤(soft-label)訓練 student-classifier. 由於各個 student-classifier 是相互獨立的,所以可以分別計算它們的預測 p s p_s ps 與 teacher-classifier 輸出的軟標籤 p t p_t pt之間的差距,論文使用 KL 散度度量這兩者之間的距離:
D K L ( p s , p t ) = ∑ i = 1 N p s ( i ) ⋅ log p s ( i ) p t ( j ) D_{KL}(p_s, p_t) = \sum_{i=1}^{N} p_s(i) \cdot \text{log} \frac{p_s(i)}{p_t(j)} DKL(ps,pt)=i=1Nps(i)logpt(j)ps(i)

其中, N N N 為分類類別的個數。將所有 student-classifier 與 teacher-classifier 的 KL 散度之和作為自蒸餾的訓練損失:
L o s s ( p s 0 , … , p s L − 2 , p t ) = ∑ i = 0 L − 2 D K L ( p s i , p t ) Loss(p_{s_{0}},\dots,p_{s_{L-2}}, p_t) = \sum_{i=0}^{L-2}D_{KL}(p_{s_i}, p_t) Loss(ps0,,psL2,pt)=i=0L2DKL(psi,pt)
其中, L L L 為主幹 Transformer block 的層數。

由於自蒸餾使用 teacher-classifier 的輸出作為標籤,因此該階段也可以使用 未標註的語料進行訓練,這也是 FastBERT 的一大亮點之一。

適應性推斷

為每層 Transformer block 配置一個 student-classifier 的好處就是 FastBERT 可以根據樣本的複雜度動態地調整 Transformer block 的計算層數。對於每一層,分別為每個樣本判斷此時的推斷結果是否足夠可信以終止往後的運算。給定一個輸入,student-classifier 輸出的 不確定性(uncertainty)計算方式如下:
U n c e r t i a n t y = ∑ i = 1 N log p s ( i ) log 1 N Uncertianty = \frac{\sum_{i=1}^{N} \text{log} p_s(i)}{\text{log} \frac{1}{N}} Uncertianty=logN1i=1Nlogps(i)

對於該不確定性,論文作者作了一個假設:不確定性約定,準確率越高(the Lower the Uncertainty, the Higher the Accuracy).

對於一個樣本,如果其在第 i i i 個 student-classifier 上得出的不確定性高於所設定的 閾值 時,那麼該樣本會繼續傳入到往後的 Transformer block 和 student-classifier 中進行運算,反之該樣本則不需要涉及進一步的運算,即就拿當前第 i i i 個 student-classifier 的輸出作為該樣本最終的輸出。

論文中將這裡的閾值定義為 Speed,其與 Uncertainty 的取值均在 0 到 1 之間。很明顯,Speed 越高,能一直堅持到更高層的樣本就越少,則總體的推斷速度就會更快,反之亦反。因此,可將 Speed 看作為推斷速度與準確率之間的權衡。根據 Speed 的不同取值,理論上 FastBERT 的速度可以提升至 BERT 的 1~12 倍。

一些分析

你可能會有疑問,給每層 Transformer block(除了最後一層)都加上一個 student-classifier,難道不會增加模型的計算量,使得模型更加緩慢嗎?論文中展示了 Transformer block 和 student-classifier 各自需要的浮點計算數量,如圖 2,可看出,相較於 Transformer block,student-classifier 的計算需求小到可以忽略。
在這裡插入圖片描述

Figure 2. 浮點計算量

再看一下在 Book review 資料集上,樣本被執行的層數的分佈圖:
在這裡插入圖片描述

Figure 3. 樣本執行層數的分佈

雖然 Speed 取值的高低會造成不同的分佈情況,但可以確定的是,大部分樣本都不會堅持到最後幾層,這也說明了 FastBERT 確實可以提升推斷速度。

對於 FastBERT 的效能可看下錶:
在這裡插入圖片描述

Figure 4. 效能對比

為了提升速度,FastBERT 的效能對比 BERT 幾乎沒有損失,這是很難得的。

思考

其實瞭解了 FastBERT 的結構後就能明白為什麼論文作者只在文字分類這種簡單的分類問題上做實驗,主要是考慮到 uncertainty 的計算問題。對於文字分類,每個樣本有且僅對應一個分類任務的 uncertainty,但對於 token-level 的分類問題(如 命名實體識別),每個樣本對應著 L L L (token 的個數)個分類任務的 uncertainty,這時如何計算樣本層面上的 uncertainty 就可以有不同選擇了,例如,可以每個 token 的 uncertainty 取平均或取最大。或許 FastBERT 本質上並不適合 token-level 的分類問題,這還有待實驗證實。

參考源