論文筆記系列-DARTS: Differentiable Architecture Search
Summary
我的理解就是原本節點和節點之間操作是離散的,因為就是從若幹個操作中選擇某一個,而作者試圖使用softmax和relaxation(松弛化)將操作連續化,所以模型結構搜索的任務就轉變成了對連續變量\(α={α^{(i,j)}}\)以及\(w\)的學習。(這裏\(α\)可以理解成the encoding of the architecture)。
之後就是叠代計算\(w\)和\(α\),這是一個雙優化問題,具體處理細節參見3.Approximation
Research Objective
作者的研究目標
在連續域內進行模型搜索,這樣就可以使用梯度下降對模型進行優化。
Problem Statement
- 離散域的結構搜索問題
NAS,ENAS其本質都是在離散空間對模型進行搜索,而文中是這麽diss這些方法的:那些方法 把結構搜索當做在離散域內的黑盒優化問題處理,這就導致了需要采樣大量的模型進行評估才能選出合適的模型,所以計算量都很大。
原文:
An inherent cause of inefficiency for the dominant approaches, e.g. based on RL, evolution, MCTS (Negrinho and Gordon, 2017), SMBO (Liu et al., 2017a) or Bayesian optimization(Kandasamy et al., 2018), is the fact that architecture search is treated as a black-box optimization problem over a discrete domain, which leads to a large number of architecture evaluations required.
- 早期的連續域結構搜索問題
DARTS並不是最早引入連續域搜索的方法,(Saxena and Verbeek, 2016; Ahmed and Torresani, 2017; Shin et al., 2018)也都是在連續域內做的結構搜索,它們主要是對模型結構的特定屬性做微調,例如卷積核大小,分支模式等。但是DARTS和這些方法還是有一些區別的:DARTS可以在豐富的搜索空間中發現具有復雜圖形拓撲的高性能體系結構,而且可以生成RNN和CNN模型。
Method(s)
本節思路如下:
1.首先以一般形式描述搜索空間。
2.之後為搜索空間引入了一個簡單的連續松弛方案(continuous relaxation scheme)1
3.最後提出一個近似方法來使算法在計算上可行且有效。
1.Search Space
根據前人的經驗,本文使用了 Cell作為模型結構搜索的基礎單元。所學習的單元可以堆疊成卷積網絡,也可以遞歸連接形成遞歸網絡。
一個Cell是由\(N\)個有序節點組成的有向非循環圖。每一個節點\(x^{(i)}\)就是一個 latent representation(例如CNN中的feature map),而\(o^{(i,j)}\)表示有向邊\((i,j)\)關於\(x^{(i)}\)的操作。
假設每個Cell有兩個輸入節點和一個輸出節點。對於卷積單元而言,輸入節點定義為前兩層的輸出(Zoph et al., 2017)。對於遞歸單元而言,輸入節點就是當前的輸入和上移時刻的狀態。單元輸出是通過對所有中間節點做 reduction operation(例如concatenation) 得到的。其中每個中間節點表達式如下:
\[x^{(i)} = \sum_{j<i}{o^{(i,j)}(x^{(j)})}\]
\(o^{(i,j)}\)中有一個特殊的操作,即\(zero\)操作,該操作表示兩個節點之間沒有連接。所以學習構建Cell的任務就簡化成了各個edge上的操作。
2.Continuous Relaxation and Optimization
令\(\mathcal{O}\)表示一組候選操作集合(如卷積,最大池化等),而每一個操作用\(o(·)\)表示。
為了使的搜索空間連續,我們將特定操作的分類選擇放寬為所有可能操作的softmax,公式如下:
\[\overline{o}^{(i,j)} = \sum_{o∈\mathcal{O}} \frac{exp(α_o^{(i,j)})}{\sum_{o'∈\mathcal{O}} exp(α_{o'}^{(i,j)})}o(x) \tag{1}\]
其中,一對節點(i,j)的操作混合權重由維度\(|\mathcal{O}|\)的矢量α參數化。
經過上面公式的松弛(relaxation)之後,模型結構搜索的任務就轉變成了對連續變量\(α={α^{(i,j)}}\)的學習,那麽\(α\)即為模型結構的編碼(encoding)如下圖所示。
搜索到最後,我們需要通過將最大可能操作(即\(o^{(i,j)}=argmax_{o∈\mathcal{O}} \,\,α_o^{(i,j)}\))代替混合操作(即\(\overline{o}\))從而得到一個離散的網絡結構參數,
為了在所有混合操作中共同學習體系結構α和權重w,DARTS使用梯度下降的方法來優化損失值。
下面將\(\mathcal{L}_{train},\mathcal{L}_{val}\)分別表示訓練集和驗證集損失值。二者均由\(α\)和\(w\)決定。最終的優化目標是找到在滿足\(w^*=argmin_w \,\, \mathcal{L}_{train}(w,α)\)的前提下找到使得\(\mathcal{L}_{val}(w^*,α^*)\)最小化的\(α^*\),用公式表示如下:
\[
min_{\,α} \,\, \mathcal{L}_{val}(w^*(α),α) \tag{1}
\]
\[
s.t. \,\,\, w^*(α)=argmin_{\,w} \,\, \mathcal{L}_{train}(w,α) \tag{2}
\]
s.t. = subject to,表示需要滿足後面的條件,即公式(1)需要在滿足公式(2)的情況下計算
3.Approximation
精確的計算雙層優化問題是很困難的,因為只要α發生任何變化,就需要通過求解公式(2)來重新計算\(w^*(α)\)。
所以本文提出了近似叠代優化過程,其中w和α通過分別在權重和架構空間中的梯度下降步驟之間交替來優化(算法見下圖Alg.1)。
算法說明:
假設在第k步,給定當前網絡結構\(α_{k-1}\),我們通過\(\mathcal{L}_{train}(w_{k-1},α_{k-1})\)計算梯度更新得到\(w_k\)。然後固定\(w_k\),通過更新網絡結構\(a_k\)來最小化驗證集損失值(公式3),其中\(\xi\)表示學習率。
\[\mathcal{L}_{val}(w',a_{k-1}) = \mathcal{L}_{val}(w_k-\xi \nabla_w \mathcal{L}_{train}(w_{k},α_{k-1}),a_{k-1}) \tag{3}\]
網絡結構梯度是通過將公式3對\(α\)求導得到的,結果如公式4(為方便書寫,用於表示步驟的k已省略)所示:
\[\nabla_α \mathcal{L}_{val}(w',α) - \xi \nabla^2_{α,w} \, \mathcal{L}_{train}(w,α) \nabla_{w'}\mathcal{L}_{val}(w',α) \tag{4}\]
上式中的第二項包含一個矩陣向量積,這是非常難計算的。但是我們知道微分可以通過如下公式進行近似:
\[f'(x)=\frac{f(x+\epsilon)-f(x-\epsilon)}{2\epsilon}\]
所以有:
\[\nabla^2_{α,w} \, \mathcal{L}_{train}(w,α) \nabla_{w'}\mathcal{L}_{val}(w',α) ≈ \frac{ \nabla_α \mathcal{L}_{train}(w^+,α) - \nabla_α \mathcal{L}_{train}(w^-,α) }{2\epsilon} \tag{7}\]
其中\(w^{+}=w+\epsilon \nabla_{w'}\mathcal{L}_{val}(w',α),w^{-}=w-\epsilon \nabla_{w'}\mathcal{L}_{val}(w',α)\)
計算有限差分只需要兩次權值前傳和兩次向後傳遞(α),復雜度從\(\mathcal{O}(|α||w|)\)降低為\(\mathcal{O}(|α|+|w|)\)。
4.Deriving Discrete Architecture
在求得連續模型結構編碼\(α\)之後,離散網絡結構求解方式如下:
Evaluation
作者如何評估自己的方法,有沒有問題或者可以借鑒的地方
Conclusion
貢獻如下:
- 引入了一種適用於卷積和循環結構的可微分網絡體系結構搜索的新算法。
- 通過實驗表明我們的方法具有很強的競爭力。
- 實現了卓越的結構搜索效率(4個GPU:1天內CIFAR10誤差2.83%; 6小時內PTB誤差56.1),這歸因於使用基於梯度的優化而非非微分搜索技術。
- 我們證明DARTS在CIFAR-10和PTB上學習的體系結構可以遷移到ImageNet和WikiText-2上
Notes
疑問:relaxation操作是什麽意思?為什麽使用softmax能將操作連續化?即公式(1)是什麽意思?\(α\)又是什麽?
https://baike.baidu.com/item/%E6%9D%BE%E5%BC%9B%E6%B3%95/12508962?fr=aladdin?
論文筆記系列-DARTS: Differentiable Architecture Search