1. 程式人生 > 實用技巧 >GOLD-NAS:針對神經網路可微分架構搜尋的一次大手術

GOLD-NAS:針對神經網路可微分架構搜尋的一次大手術

GOLD-NAS:針對神經網路可微分架構搜尋的一次大手術

本文是華為,清華與同濟大學聯合發表的基於可微分網路搜尋的論文。本文基於傳統DARTS搜尋空間受限,二階優化不穩定,超網路離散化誤差大的這三個問題,對DARTS進行了一場全面的手術。本文針對上述三個問題提出了漸進式的(Gradual),一階段的(One-Level),可微分的(Differentiable) 神經網路架構搜尋(GOLD-NAS)

。在標準的影象分類基準中,GOLD-NAS 可以找到單個搜尋過程中的一系列的帕累托最優架構。

  • 論文題目:GOLD-NAS: Gradual, One-Level, Differentiable
  • 開原始碼:https://github.com/sunsmarterjie/GOLD_NAS

DARTS目前的問題

搜尋空間受限

DARTS的搜尋空間非常有限,例如,對於每個邊保留了一個運算子,每個節點固定接收兩個前繼輸入,等等。這些約束有利於NAS搜尋的穩定性,但它們也縮小了強大的搜尋方法帶來的準確性。最典型的,某些啟發式設計(例如,每個單元格中有兩個跳連運算子)或搜尋技巧(例如,Early Stop),甚至隨機搜尋也可以達到令人滿意的效果。

雙層優化不穩定

DARTS需要雙層優化,即訓練階段以優化網路權重驗證階段以更新體系結構引數。 這種機制帶來了計算負擔,更重要的是,梯度估計存在很大的不準確性可能會大大降低搜尋過程

離散化誤差大

在超級網路建立之後,DARTS會立即剪掉弱的運算子和邊,但此步驟可能會帶來較大的離散誤差,尤其是當權重為被修剪的運算子不能保證很小。

GOLD-NAS的解決方案

搜尋空間重定義

  • 不同的Cell可以擁有不同的結構
  • 每條邊可以包含超過一個操作
  • 每條邊只保留兩個操作(Skip-Connect 和 Sep-conv-3x3)
  • 每個節點可以選擇從任數量的前繼作為輸入
  • 訓練過程中,操作引數從原始的Softmax歸一化(競爭)修改為Signoid元素化(獨立)

一階段優化

可微分NAS的目標是解決以下優化問題:

α ⋆ = arg ⁡ min ⁡ α L ( ω ⋆ ( α ) , α ; D train ⁡ ) , s.t. ω ⋆ ( α ) = arg ⁡ min ⁡ ω L ( ω , α ; D train ) \boldsymbol{\alpha}^{\star}=\arg \min _{\boldsymbol{\alpha}} \mathcal{L}\left(\boldsymbol{\omega}^{\star}(\alpha), \boldsymbol{\alpha} ; \mathcal{D}_{\operatorname{train}}\right), \quad \text { s.t. } \quad \boldsymbol{\omega}^{\star}(\alpha)=\arg \min _{\boldsymbol{\omega}} \mathcal{L}\left(\boldsymbol{\omega}, \boldsymbol{\alpha} ; \mathcal{D}_{\text {train }}\right) α=argαminL(ω(α),α;Dtrain),s.t.ω(α)=argωminL(ω,α;Dtrain)

一階段優化目標旨在同步更新架構引數 α ⋆ \boldsymbol{\alpha}^{\star} α和網路權重 ω ⋆ \boldsymbol{\omega}^{\star} ω:

ω t + 1 ← ω t − η ω ⋅ ∇ ω L ( ω t , α t ; D train ⁡ ) , α t + 1 ← α t − η α ⋅ ∇ α L ( ω t , α t ; D train ⁡ ) \omega_{t+1} \leftarrow \omega_{t}-\eta_{\omega} \cdot \nabla_{\omega} \mathcal{L}\left(\omega_{t}, \alpha_{t} ; \mathcal{D}_{\operatorname{train}}\right), \quad \alpha_{t+1} \leftarrow \alpha_{t}-\eta_{\alpha} \cdot \nabla_{\alpha} \mathcal{L}\left(\omega_{t}, \alpha_{t} ; \mathcal{D}_{\operatorname{train}}\right) ωt+1ωtηωωL(ωt,αt;Dtrain),αt+1αtηααL(ωt,αt;Dtrain)

根據NAS的搜尋架構我們瞭解到架構引數 α ⋆ \boldsymbol{\alpha}^{\star} α(10數量級)和網路權重 ω ⋆ \boldsymbol{\omega}^{\star} ω(百萬數量級)存在很大的引數數量差距。因此,在之前的訓練中往往採用不同的優化器設定不同的優化引數,但是,由於引數數量上的差距優化器仍然會趨向於優化網路引數 ω ⋆ \boldsymbol{\omega}^{\star} ω。本文,將訓練集進一步切分為兩部分 D train = D 1 ∪ D 2 \mathcal{D}_{\text {train }}=\mathcal{D}_{1} \cup \mathcal{D}_{2} Dtrain=D1D2,分別訓練架構引數和網路權重:

α ⋆ = arg ⁡ min ⁡ α L ( ω ⋆ ( α ) , α ; D 1 ) , s.t. ω ⋆ ( α ) = arg ⁡ min ⁡ ω L ( ω , α ; D 2 ) \boldsymbol{\alpha}^{\star}=\arg \min _{\boldsymbol{\alpha}} \mathcal{L}\left(\boldsymbol{\omega}^{\star}(\alpha), \boldsymbol{\alpha} ; \mathcal{D}_{1}\right), \quad \text { s.t. } \quad \boldsymbol{\omega}^{\star}(\alpha)=\arg \min _{\boldsymbol{\omega}} \mathcal{L}\left(\boldsymbol{\omega}, \boldsymbol{\alpha} ; \mathcal{D}_{2}\right) α=argαminL(ω(α),α;D1),s.t.ω(α)=argωminL(ω,α;D2)

ω t + 1 ← ω t − η ω ⋅ ∇ ω L ( ω t , α t ; D 2 ) , α t + 1 ← α t − η α ⋅ ∇ α L ( ω t + 1 , α t ; D 1 ) \omega_{t+1} \leftarrow \omega_{t}-\eta_{\omega} \cdot \nabla_{\omega} \mathcal{L}\left(\omega_{t}, \alpha_{t} ; \mathcal{D}_{2}\right), \quad \alpha_{t+1} \leftarrow \alpha_{t}-\eta_{\alpha} \cdot \nabla_{\alpha} \mathcal{L}\left(\omega_{t+1}, \alpha_{t} ; \mathcal{D}_{1}\right) ωt+1ωtηωωL(ωt,αt;D2),αt+1αtηααL(ωt+1,αt;D1)

一級優化往往存在更嚴重的搜尋不穩定問題,文中給出的解決方法是:針對小資料集(CIFAR-10)可以在訓練過程中新增正則化(例如,Cutout或AutoAugment); 亦或是直接在大資料集(ImageNet)上搜索

基於資源約束的漸進式剪枝

傳統的DARTS在超網路訓練完成後按規則進行離散化剪枝,但是在兼職過程中會產生巨大的離散化誤差。本文,為了解決離散化誤差的問題採用的漸進式剪枝過程,並且在多次剪枝過程中,每次剪掉引數趨於0的操作,儘量避免因剪枝造成的離散誤差。另外,為了實現架構引數在訓練過程中趨於0或者1,本文添加了基於資源約束的正則化Loss:

L ( ω , α ) = E ( x , y ⋆ ) ∈ D train [ C E ( f ( x ) , y ⋆ ) ] + λ ⋅ ( F L O P s ‾ ( α ) + μ ⋅ F L O P s ( α ) ) \mathcal{L}(\boldsymbol{\omega}, \boldsymbol{\alpha})=\mathbb{E}_{\left(\mathbf{x}, \mathbf{y}^{\star}\right) \in \mathcal{D}_{\text {train }}}\left[\mathrm{CE}\left(f(\mathbf{x}), \mathbf{y}^{\star}\right)\right]+\lambda \cdot(\overline{\mathrm{FL} \mathrm{OPs}}(\boldsymbol{\alpha})+\mu \cdot \mathrm{FLOPs}(\boldsymbol{\alpha})) L(ω,α)=E(x,y)Dtrain[CE(f(x),y)]+λ(FLOPs(α)+μFLOPs(α))

GOLD-NAS 演算法流程圖

GOLD-NAS 演算法流程圖

結果

帕累託邊界

帕累託邊界

CIFAR-10 結果

CIFAR-10 結果

CIFAR-10搜尋結果視覺化:紅色線代表Skip-Connect;藍色線代表sep-conv-3x3

ImageNet 結果

ImageNet 結果
ImageNet搜尋結果視覺化:紅色線代表Skip-Connect;藍色線代表sep-conv-3x3


更多內容關注微信公眾號【AI異構】