1. 程式人生 > >模型驅動的深度學習(ADMM-net)

模型驅動的深度學習(ADMM-net)

for 高精 高精度 不同 height 梯度 深度學習 減少 需求

流程:模型族->算法族->深度網絡->深度學習

模型族:模型中含有超參數,給予不同的參數對應不同的模型,就形成了模型族

算法族:每一個模型對應一個完整算法,整個模型族對應了一個算法族 將算法族展開成一個深度網絡,網絡層數代表叠代次數,模型的超參數成為網絡中的參數(如權重等)。利用少量標記數據就可以訓練網絡。

相對於模型驅動算法的優勢:

   1、可以學習模型超參數,提高了模型的適應能力,提高精度

相對於數據驅動的優勢:

   1、網絡的設計有模型指導

   2、減少了數據需求量

   3、減小了訓練時間

比如核磁共振重建的ADMM算法:

模型:

\(x^*={arg\max}_{x}{\{\frac{1}{2}||Ax-y||^2+\sum_{l=1}^{L}\lambda_{l}g(D_{l}x)\}}\)

ADMM算法求解:

技術分享圖片

\(g,\lambda,L,D_{l}\)的不同選擇形成了不同的模型,構成了模型族。

廣義拉格朗日函數:

技術分享圖片

ADMM算法叠代更新過程:

技術分享圖片

令\(\beta_{l}=\frac{\alpha_{l}}{\rho_{l}},A=PF\)(已知),可得

技術分享圖片

\(S(\cdot)\)是一個非線性shrinkage function。\(S(\cdot)\)通常是一個光滑函數。

網絡結構:

技術分享圖片

包括重建層\(X^{(n)}\)、卷積層\(C^{(n)}=D_{l}x^{(n)}\)、非線性變換層\(Z^{(n)}\)、乘子更新層\(M^{(n)}\),其中非線性變換函數\(S\)可以用分段線性函數近似,只需學習插值點的函數值即可。

網絡學習的參數:模型族中的超參數,每一層可以不一樣。

網咯訓練:

損失函數為

技術分享圖片

梯度下降法訓練。

參考文獻:

yangyan,sunjian,lihuibin,xuzongben, Deep ADMM-Net for Compressive Sensing MRI (NIPS2017)

https://arxiv.org/abs/1705.06869

模型驅動的深度學習(ADMM-net)