1. 程式人生 > >論文筆記系列-Efficient Neural Architecture Search via Parameter Sharing

論文筆記系列-Efficient Neural Architecture Search via Parameter Sharing

結構 obj 接收 論文 method objective 系列 其他 步驟

Summary

本文提出超越神經架構搜索(NAS)的高效神經架構搜索(ENAS),這是一種經濟的自動化模型設計方法,通過強制所有子模型共享權重從而提升了NAS的效率,克服了NAS算力成本巨大且耗時的缺陷,GPU運算時間縮短了1000倍以上。在Penn Treebank數據集上,ENAS實現了55.8的測試困惑度;在CIFAR-10數據集上,其測試誤差達到了2.89%,與NASNet不相上下(2.65%的測試誤差)

Research Objective 作者的研究目標

設計一種快速有效且耗費資源低的用於自動化網絡模型設計的方法。主要貢獻是基於NAS方法提升計算效率,使得各個子網絡模型共享權重,從而避免低效率的從頭訓練。

Problem Statement 問題陳述,要解決什麽問題?

本文提出的方法是對NAS的改進。NAS存在的問題是它的計算瓶頸,因為NAS是每次將一個子網絡訓練到收斂,之後得到相應的reward,再將這個reward反饋給RNN controller。但是在下一輪訓練子網絡時,是從頭開始訓練,而上一輪的子網絡的訓練結果並沒有利用起來。

另外NAS雖然在每個節點上的operation設計靈活度較高,但是固定了網絡的拓撲結構為二叉樹。所以ENAS對於網絡拓撲結構的設計做了改進,有了更高的靈活性。

Method(s) 解決問題的方法/算法

ENAS算法核心

回顧NAS,可以知道其本質是在一個大的搜索圖中找到合適的子圖作為模型,也可以理解為使用單個有向無環圖(single directed acyclic graph, DAG)來表示NAS的搜索空間。

基於此,ENAS的DAG其實就是NAS搜索空間中所有可能的子模型的疊加。

下圖給出了一個通用的DAG示例

技術分享圖片

如圖示,各個節點表示本地運算,邊表示信息的流動方向。圖中的6個節點包含有多種單向DAG,而紅色線標出的DAG則是所選擇的的子圖。

以該子圖為例,節點1表示輸入,而節點3和節點6因為是端節點,所以作為輸出,一般是將而二者合並求均值後輸出。

在討論ENAS的搜索空間之前,需要介紹的是ENAS的測試數據集分別是CIFAR-10和Penn Treebank,前者需要通過ENAS生成CNN網絡,後者則需要生成RNN網絡。

所以下面會從生成RNN和生成CNN兩個方面來介紹ENAS算法。

1.Design Recurrent Cells

本小節介紹如何從特定的DAG和controller中設計一個遞歸神經網絡的cell(Section 2.1)?

首先假設共有\(N\)個節點,ENAS的controller其實就是一個RNN結構,它用於決定

  • 哪條邊需要激活
  • DAG中每個節點需要執行什麽樣的計算

下圖以\(N=4\)為例子展示了如何生成RNN。

技術分享圖片

假設\(x[t]\)為輸入,\(h[t-1]\)表示上一個時刻的輸出狀態。

  • 節點1:由圖可知,controller在節點1上選擇的操作是tanh運算,所以有\(h_1=tanh(X_t·W^{(X)}+h_{t-1}·W_1^{(h)})\)
  • 節點2:同理有\(h_2 = ReLU(h_1·W_{2,1}^{(h)})\)
  • 節點3:\(h_3 = ReLU(h_2·W_{3,2}^{(h)})\)
  • 節點4:\(h_4 = ReLU(h_1·W_{4,1}^{(h)})\)
  • 節點3和節點4因為不是其他節點的輸入,所以二者的平均值作為輸出,即\(h_t=\frac{h_3+h_4}{2}\)

由上面的例子可以看到對於每一組節點\((node_i,node_j),i<j\),都會有對應的權重矩陣\(W_{j,i}^{(h)}\)。因此在ENAS中,所有的recurrent cells其實是在搜索空間中共享這樣一組權重的。

2.1 Design Convolutional Networks

本小節解釋如何設計卷積結構的搜索空間

回顧上面的Recurrent Cell的設計,我們知道controller RNN在每一個節點會做如下兩個決定:a)該節點需要連接前面哪一個節點 b)使用何種激活函數。

而在卷積模型的搜索空間中,controller RNN也會做如下兩個覺得:a)該節點需要連接前面哪一個節點 b)使用何種計算操作。

在卷積模型中,(a)決定 (連接哪一個節點) 其實就是skip connections。(b)決定一共有6種選擇,分別是3*3和5*5大小的卷積核、3*3和5*5大小的深度可分離卷積核,3*3大小的最大池化和平均池化。

下圖展示了卷積網絡的生成示意圖。

技術分享圖片

2.2 Design Convolutional Cell

本文並沒有采用直接設計完整的卷積網絡的方法,而是先設計小型的模塊然後將模塊連接以構建完整的網絡(Zoph et al., 2018)。

下圖展示了這種設計的例子,其中設計了卷積單元和 reduction cell。

技術分享圖片

接下來將討論如何利用 ENAS 搜索由這些單元組成的架構。

假設下圖的DAG共有\(B\)個節點,其中節點1和節點2是輸入,所以controller只需要對剩下的\(B-2\)個節點都要做如下兩個決定:a)當前節點需要與那兩個節點相連 b)所選擇的兩個節點需要采用什麽樣的操作。(可選擇的操作有5種:identity(id,相等),大小為3*3或者5*5的separate conv(sep),大小為3*3的最大池化。)

可以看到對於節點3而言,controller采樣的需要連接的兩個節點都是節點2,兩個節點預測的操作分別是sep 5*5和identity。

技術分享圖片

3.Training ENAS and Deriving Architectures

本小節介紹如何訓練ENAS以及如何從ENAS的controller中生成框架結構。(Section 2.2)

controller網絡是含有100個隱藏節點的LSTM。LSTM通過softmax分類器做出選擇。另外在第一步時controller會接收一個空的embedding作為輸入。

在ENAS中共有兩組可學習的參數:

  • 子網絡模型的共享參數,用\(w\)表示。
  • controller網絡(即LSTM網絡參數),用\(θ\)表示。

而訓練ENAS的步驟主要包含兩個交叉階段:第一部訓練子網絡的共享參數\(w\);第二個階段是訓練controller的參數\(θ\)。這兩個階段在ENAS的訓練過程中交替進行,具體介紹如下:

子網絡模型共享參數\(w\)的訓練

在這個步驟中,首先固定controller的policy network,即\(π(m;θ)\)。之後對\(w\)使用SGD算法來最小化期望損失函數\(E_{m~π}[L(m;w)]\)

其中\(L(m;w)\)是標準的交叉熵損失函數:\(m\)表示根據policy network \(π(m;θ)\)生成的模型,然後用這個模型在一組訓練數據集上計算得到的損失值。

根據Monte Carlo估計計算梯度公式如下:

\[\nabla_w E_{m-~π}(m;θ)[L(m;w)] ≈ \frac{1}{M} \sum_i^M \nabla_wL(m_i;w) \]

其中上式中的\(m_i\)表示由\(π(m;θ)\)生成的M個模型中的某一個模型。

雖然上式給出了梯度的無偏估計,但是方差比使用SGD得到的梯度的方差大。但是當\(M=1\)時,上式效果還可以。

訓練controller參數θ

在這個步驟中,首先固定\(w\),之後通過求解最大化期望獎勵\(E_{m~π}[R(m;w)]\)來更新\(θ\)

導出模型架構

首先使用\(π(m,θ)\)生成若幹模型。

之後對於每一個采樣得到的模型,直接計算其在驗證集上得到的獎勵。

最後選擇獎勵最高的模型再次從頭訓練。

當然如果像NAS那樣把所有采樣得到的子模型都先從頭訓練一邊,也許會對實驗結果有所提升。但是ENAS之所以Efficient,就是因為它不用這麽做,原理繼續看下文。

Evaluation 評估方法

1.在 Penn Treebank 數據集上訓練的語言模型

技術分享圖片
技術分享圖片

2.在 CIFAR-10 數據集上的圖像分類實驗

技術分享圖片

由上表可以看出,ENAS的最終結果不如NAS,這是因為ENAS沒有像NAS那樣從訓練後的controller中采樣多個模型架構,然後從中選出在驗證集上表現最好的一個。但是即便效果不如NAS,但是ENAS效果並不差太多,而且訓練效率大幅提升。

下圖是生成的宏觀搜索空間。

技術分享圖片

ENAS 用了 11.5 個小時來發現合適的卷積單元和 reduction 單元,如下圖所示。

技術分享圖片

Conclusion

ENAS能在Penn Treebank和CIFAR-10兩個數據集上得到和NAS差不多的效果,而且訓練時間大幅縮短,效率大大提升。

技術分享圖片



MARSGGBO?原創





2018-8-7



論文筆記系列-Efficient Neural Architecture Search via Parameter Sharing