關於多工學習(MTL),我們應該知道的事
概念相關
1. Multi-task learning
關於多工學習的定義並沒有統一的標準,這裡引用《A survey on multi-task learning》中的定義:
與標準的單任務相比,在學習共享表示的同時訓練多個任務有兩個主要挑戰:
- Loss Function(how to balance tasks):多工學習的損失函式,對每個任務的損失進行權重分配,在這個過程中,必須保證所有任務同等重要,而不能讓簡單任務主導整個訓練過程。手動的設定權重是低效而且不是最優的,因此,自動的學習這些權重或者設計一個對所有權重具有魯棒性的網路是十分必要和重要的。
- Network Architecture(how to share):
2. Auxiliary Learning(輔助學習):
除了同時學習多個任務,在有些情況下,我們的關注點只是多工中的一個或者幾個任務的表現。為了更好的理解任務之間的相關性,我們可以通過設定帶有各種屬性的輔助任務來進行。輔助任務的目的就是協助我們找到一個更強大,更具有魯棒性的特徵表示,最終讓主要任務受益。關於輔助任務的定義,我們可以根據上文的多工定義進行延伸,如下表示:
3. Multi-Task Framework Design
這一章節我將以問答的形式進行,通過提出問題,解決問題,讓大家更容易理解多工的框架設計方法
Q:How to properly balance different types of tasks such that training multi-task networks will not be dominated by the easier task(s)?
分析:第一個問題是,在設計多工網路過程中,我們如何平衡不同型別的任務,避免在訓練過程中,整個網路被簡單任務主導,導致任務之間的效能差異巨大。這就涉及到為不同任務的loss function賦上不同的權重,將不同task之間的loss統一成一個損失函式,如果只是簡單的將不同任務的loss相加,這樣會造成最終模型在有些任務上表現很好,在有的任務上大失水準。背後的原因是不同任務的不同損失函式尺度有很大的差異,因此需要考慮用權值將每個損失函式的尺度統一。
A:針對這個問題,最新的解決辦法是cvpr2018的一個工作《Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics》,這篇文章提出,將不同的loss拉到統一尺度下,這樣就容易統一,具體的辦法就是利用同方差的不確定性,將不確定性作為噪聲,進行訓練,詳細的講解可以看我專欄文章:
路一直都在:利用不確定性來衡量多工學習中的損失函式zhuanlan.zhihu.com
這裡在簡單的講一下同方差的不確定性(Homoscedastic Uncertainty):屬於偶然不確定性,這種不確定性捕捉了不同任務之間的相關性置信度,這種不確定性可以作為不同任務loss賦值的衡量標準。
這裡,引用一下相關文章,我會在文末附上參考,若侵權,請與我聯絡
其中,Fw是神經網路的輸出,在一個有K個任務的模型中,似然估計可以表示為通過概率累乘得到,則極大似然估計可以寫成:
其中,σ類似神經網路的引數w,都是可以通過反向傳播進行更新,表示的是每個任務輸出的置信度。分析上式可知,如果σ增加,相對應的任務loss的權重就會減小,這樣就實現了權重的動態規劃。
Q:How to build a multi-task learning architecture which is easy to train,parameter-efficient and robust to task weighting?
分析:如何構建一個統一,易訓練,高魯棒的多工網路,有多種思想,但是,一個優秀的多工網路應該具備:(1)特徵共享部分和任務特定部分都能自動學習(2)對損失函式權重的選擇上更robust
A:如下圖所示,關於特徵共享表示,一般有兩種方法,Hard-parameter sharing和soft-parameter-sharing。hard-parameter sharing有一組相同的特徵共享層,這種設計大大減少了過擬合的風險;soft-parameter sharing每一個任務都有自己的特徵引數和每個子任務之間的約束,這種設計更robust。
當下最主流的框架都是兩種框架的結合,通過結合,能夠找到特徵共享部分和特定任務部分很好的協調,下面介紹常見的多工網路的結構設計:
- Fusion Network(融合網路)
Fusion Network是一種通用的特徵學習網路,每個任務的上層共享表示是通過學習特定任務的引數,將所有任務的低層特徵表示通過線性組合表示出來。代表的網路結構有“Cross-Stitch Network”十字繡網路,瞭解更多關於該網路,可以去看論文原文。(論文解讀可以看我的專欄文章,點我跳轉)
十字繡網路結構如下圖所示,十字繡網路只在池化層或者全連線層之後加上十字繡單元
計算過程可以用下式表示:
可以通過設定AB和BA的值控制共享的程度。