1. 程式人生 > >模型壓縮之 BinaryNet

模型壓縮之 BinaryNet

1. 動機

深度學習在影象、語音、文字等領域都取得了巨大的成功,推動了一系列智慧產品的落地。但深度模型存在著引數眾多,訓練和 inference 計算量大的不足。目前,基於深度學習的產品大多依靠伺服器端運算能力的驅動,非常依賴良好的網路環境。

很多時候,出於響應時間、服務穩定性和隱私方面的考慮,我們更希望將模型部署在本地(如智慧手機上)。為此,我們需要解決模型壓縮的問題——將模型大小、記憶體佔用、功耗等降低到本地裝置能夠承受的範圍之內。

2. 方法

神經網路具有分散式的特點——特徵表徵和計算都分散於各個層、各個引數。因此,神經網路在結構上天然具有冗餘的特點。冗餘是神經網路進行壓縮的前提。

壓縮模型一般可以有幾種常見的方法:

2.1 使用小模型

設計小模型

可以直接將模型大小做為約束,在模型結構設計和選擇時便加以考慮。對於全連線,使用 bottleneck 是一個有效的手段(如 LSTMP)。HighwayResNetDenseNet 等帶有 skip connection 結構的模型也被用來設計窄而深的網路,從而減少模型整體引數量和計算量。對 CNN 網路,SqueezeNet 通過引入1 x 1的小卷積核、減少 feature map 數量等方法,在分類精度與 AlexNet 相當的前提下,將模型大小壓縮在 1M 以內,而模型大小僅是 Alexnet 的50分之一

。更新的還有 MobileNet、ShuffleNet 等。

模型小型化

一般而言,相比於小模型,大模型更容易通過訓練得到更優的效能。那麼,能否用一個較小的模型,“提煉”出訓練好的大模型的知識能力,從而使得小模型在特定任務上,達到或接近大模型的精度?Knowledge Distilling(e.g. 12)便嘗試解決這一問題。knowledge distilling 將大模型的輸出做為 soft target 來訓練小模型,達到知識“凝練“的效果。實驗表明,distilling 方法在 MNIST 及聲學建模等任務上有著很好的表現。

2.2 利用稀疏性

我們也可以通過在模型結構上引入稀疏性,從而達到減少模型引數量的效果。

裁剪已有模型

將訓練好的模型進行裁剪的方法,至少可以追溯到90年代。 Optimal Brain DamageOptimal Brain Surgeon 通過一階或二階的梯度資訊,刪除不對效能影響不顯著的連線,從而壓縮模型規模。

學習稀疏結構

稀疏性也可以通過訓練獲得。更近的一系列工作(Deep compression: abcHashedNets)在控制模型效能的前提下,學習稀疏的模型結構,從而極大的壓縮模型規模。

2.3 降低運算精度

不同傳統的高效能運算,神經網路對計算精度的要求不高。目前,基本上所有神經網路都採用單精度浮點數進行訓練(這在很大程度上決定著 GPU 的架構設計)。已經發布的 NVIDIA Pascal 架構的最大特色便是原生的支援半精度(half float)運算。在服務端,FPGA 等特殊硬體在許多資料中心得到廣泛應用,多采用低精度(8 bit)的定點運算。

引數量化

除了使用低精度浮點運算(float32, float16)外,量化引數是另一種利用簡化模型的有效方法。
將引數量化有如下二個優勢:
* 減少模型大——將 32 或 16 位浮點數量化為 8 位甚至更少位的定點數,能夠極大減少模型佔用的空間;
* 加速運算——相比於複雜的浮點運算,量化後的定點運算更容易利用特殊硬體(FPGA,ASIC)進行加速。

上面提到的 Deep Compression 使用不同的位數量化網路。Lin 等的工作,在理論上討論上,在不損失效能的前提下,CNN 的最優量化策略。此外,還有量化 CNNRNN 權值的相關工作。

引數二值化

量化的極限是二值化,即每一個引數只佔用一個 bIt。本文討論的正是這個種壓縮模型的方法。

3. BinaryNet

BinaryNet [1] 研究物件是前饋網路(全連線結構或卷積結構)(這方法在 RNN 上並不成功 [4])。這裡,我們更關心權值的二值化對 inference 的精度和速度的影響,而不關心模型的訓練速度(量化梯度以加速模型訓練的工作可以參見 [3])。

前饋模型(卷積可以看成是一種特殊的全連線)可以用如下公式表示:

xk=σ(Wkxk1)

其中,xk 為第k 層的輸入,Wk 為第 k 層的權值矩陣,σ() 為非線性啟用函式。由於 Batch Normalizaiton 的引入,偏置項 b 成為冗餘項,不再考慮。

3.1. 二值化權值和啟用

首先,我們定義取符號操作:

sign(x)={1ifx0,1

在 BinaryNet 中,網路權值為 +1 或 -1,即可以用 1bit 表示,這一點與 BinaryConnect 相同。更進一步,BinaryNet 使用了輸出為二值的啟用函式,即:

σ(x)=sign(x)

這樣,除了第一層的輸入為浮點數或多位定點數外,其他各層的輸入都為 1 bit。

3.2. 訓練

BinaryNet 二值化權值和啟用的思路很容易理解,但關鍵在於,如何有效地訓練網路這樣一個二值網路。

[1] 提出的解決方案是:權值和梯度在訓練過程中保持全精度(full precison),也即,訓練過程中,權重依然為浮點數,訓練完成後,再將權值二值化,以用於 inference。

權值

在訓練過程中,權值為 32 位的浮點數,且取值值限制在 [-1, 1] 之間,以保持網路的穩定性。為此,訓練過程中,每次權值更新後,需要對權值 W 的大小進行檢查,W=max(min(1,W),1)

前向

前向運算時,我們首先得到二值化的權值:Wbk=sign(Wk),k=1,,n
然後,用 Wbk 代替 Wk

xk=σ(BN(Wbkxk1)=sign(BN(Wbkxk1))

其中,BN() 為 Batch Normalization 操作。

後向

根據誤差反傳演算法(Backpropagation,BP),由於 sign() 的導數(幾乎)處處為零,因此,W 通過 BP 得到的誤差 ΔW 為零 ,因此不能直接用來更新權值。為解決這個問題,[1] 採用 straight-through estimator(Section 1.3) 的方法,用 ΔWb 代替 ΔW。這樣,BinaryNet 就可以和序普通的實值網路一樣,使用梯度下降法進行優化。
另外,如果啟用值(二值化之前)的絕對值大於1,相應的梯度也要置零,否則會影響訓練效果。

3.4. 效能

模型精度

BinaryNet 在 MNIST (MLP) ,CIFAR10、SVHN(CNN)上取得了不錯的結果(表1第二列)。

資料集 論文結果 squared hinge loss (同論文) xent loss
MNIST (MLP) 0.96% 1.06% 1.02%
CIFAR10 (CNN) 11.40% 11.92% 11.91%
SVHN (CNN) 2.80% 2.94% 2.82%

表 1 不同資料集上錯誤率

壓縮效果

二值化網路在運算速度、記憶體佔用、能耗上的優勢是顯而易見的,這也是我們對二值化感興趣的原因。[1] 中給出了這方面的一些分析,具體可以參見 [1](Section 3),此處不再贅述。

4. Source Code

4.1 訓練

BinaryNet[1]的作者給出了 theanotorch 兩個版本,兩者之間略有不同。theano 採用確定性(deterministic)的二值化方式,而 torch 是隨機化(stochastic)的二值化,並且 torch 版對 Batch Normalization 操作也進行了離散化處理。具體差異可以參見論文。

根據文章 theano 版本的實現,我們有基於 Keras 的實現。這個版本利用了一個 trick ,實現了梯度的 straight-through estimator。

理想情況下,theano 和 tensorflow 在做 Graph 優化時,應該能夠優化掉這個 trick 帶來的效能開銷,但對於MLP, tensorflow 的後端明顯比 theano 慢(~235s vs. ~195s),但不清楚是否是兩者對 Graph 優化能力差異造成的。

在 MNIST、CIFAR10 和 SVHN 資料集上,基本復現了文章的結果(見表1. 三、四列)

與文章聲稱的不同,目標函式分別選擇交叉熵(crossentropy, xent)與合葉損失(L2-SVM)時,在三個資料集上的效能幾乎沒有判別,甚至交叉熵還要略好一些。

另外,感興趣的讀者可以參考基於 pytorch 的實現

4.2 Inference

正如上面介紹的,BinaryNet 的最大優點是可以 XNOR-計數 運算替代複雜的乘法-加法操作。[1] 給出了二值網路 inference 的基於 CUDA 的 GPU 參考實現。另外,還有基於 CPU 的實現(其基於 tensorflow 的訓練程式碼有些小問題)。[2] 報告了基於 FPGA 的實現及加速效果。

5. 結語

  • 在小型任務上,BinaryNet 完全有希望滿足精度要求。目前手裡沒有真實應用的資料(如語音的靜音檢測),不能進一步驗證可行性。
  • 至於 BinaryNet 在大型任務上的效能,從 [5][6] 報告的 ImageNet 準確率來看, 效能損失還是比較大的。更高的量化精度似是必須的[4][7]。
  • 此外,根據實驗經驗,BinaryNet 的訓練不太穩定,需要較小的學習率,收斂速度明顯慢於實值網路。

References

Further Readings