【元學習】Meta Learning 介紹
目錄
- 元學習(Meta-learning)
- 元學習被用在了哪些地方?
- Few-Shot Learning(小樣本學習)
- 最近的元學習方法如何工作
- Model-Agnostic Meta-Learning (MAML)
元學習(Meta-learning)
智慧的一個關鍵方面是多功能性——做許多不同事情的能力。當前的AI系統可以做到精通於某一項技能,但是,如果我們要求AI系統執行各種看似簡單的問題(用同一個模型去解決不同問題),它將會變得十分困難。相反,人類可以明智地利用以往經驗並採取行動以適應各種新的情況。因此我們希望 agent 能夠像人類一樣利用以往經驗來解決新的問題,而不是將解決新問題的方法從頭學起。Learning to learn 或者 meta-larning 是朝這個方向發展的關鍵一步,它們可以在其生命週期內不斷學習各種任務。
元學習被用在了哪些地方?
元學習通常被用在:優化超引數和神經網路、探索好的網路結構、小樣本影象識別和快速強化學習等。
Few-Shot Learning(小樣本學習)
2015年, Brendan Lake et al. 發表一篇論文,這對現代機器學習方法提出了挑戰,他認為機器可以從一個或者幾個例項中學習到新概念。後繼的兩篇論文 memory-augmented neural networks 和 sequential generative models 表明,對深度模型來說,其可以從少量例項中進行學習,儘管還沒有達到人類水平。
最近的元學習方法如何工作
元學習系統要接受大量任務(tasks)的訓練,並預測其學習新任務的能力。這種任務可能是對新影象分類(給定每個類別就只有一個示例),其中有5種可能類別;或者是這樣一種任務,僅通過學習一個迷宮就可以有效地在新的迷宮中導航。這與許多標準的機器學習技術不同,後者涉及對單個任務地訓練,並且保留了一些示例對該任務進行測試。下圖是影象分類領域運用元學習的示例:
在元學習過程中,訓練模型以學習 meta-training set 中的任務,這其中有兩個優化在起作用:learner:學習新任務;meta-learner:訓練 learner。元學習的方法通常分為三類:(1)recurrent models,(2)metric learning, (3)learning optimizers。
這裡重點介紹第三種方法,即學習一個優化器。在這種方法裡,有兩個網路,分別是 meta-learner 和 learner。前者學習如何更新後者,使得後者能夠有效學習新任務。這種方法已被用於研究更好的神經網路優化。meta-learner 通常是迴圈網路(recurrent network),這樣才能記得之前是怎樣更新 learner 模型。此外,meta-learner 可以使用強化學習或者監督學習進行訓練。
Model-Agnostic Meta-Learning (MAML)
MAML 的思路就是直接針對初始表示進行優化,其中這種初始表示可以通過少量示例進行有效地調整。像其他 meta-learning 方法一樣,MAML 也是通過許多 tasks 進行訓練,訓練所得表徵可以通過很少梯度迭代就能適應新任務。MAML 試圖尋找這樣一種初始化,不僅有效適用不同任務,而且要快速適應(僅需要幾步)和有效適應(只使用很少樣例)。觀看下圖,假設我們正在尋找一組有很強適應性的引數 \(\theta\) 。在元學習過程中(實線部分),MAML 針對一組引數進行優化,以使得對特定任務 \(i\) (灰線部分)採取梯度步驟時,這些引數可以接近最佳引數 \(\theta_i^*\) 。
以 MAML 為例介紹元學習一些相關概念
1. N-way K-shot:這是 few-shot learning 中常見的實驗設定,N-way 指訓練資料中有 N 個類別,K-shot 指每個類別下有 K 個被標記資料。
2. model-agnostic:即指模型無關。MAML 相當於一個框架,提供一個 meta learner 用於訓練 learner。meta-learner 是 MAML 的精髓所在,用於 learning to learn;而 learner 則是在目標資料集上被訓練,並實際用於預測任務的真正數學模型。絕大多數深度學習模型都可以作為 learner 無縫嵌入 MAML 中,MAML 甚至也可以用於強化學習中,這就是 MAML 中模型無關的含義。
3. task:這在 MAML 中是一個很重要的概念。我們首先需要了解的概念:\(D_{meta-train}, D_{meta-test}\),support set,query set,meta-train classes,meta-test classes等等。假設一個這樣的場景:我們需要利用 MAML 訓練一個數學模型 \(M_{fine-tune}\),目的是對未知標籤圖片做分類,類別包括\(P_1 \sim P_5\)(每類有 5 個已標註樣本用於訓練,另外 15 個已標註樣本用於測試)。我們的訓練資料除了 \(P_1 \sim P_5\) 中已標註的樣本外,還包括另外 10 個類別的圖片 \(C_1 \sim C_{10}\)(每類有 30 個已標註樣本),用於幫助訓練元學習模型 \(M_{meta}\)。
此時, \(C_1 \sim C_{10}\) 即為 meta-train classes, \(C_1 \sim C_{10}\) 包含的 300 個樣本即為 \(D_{meta-train}\),作為訓練 \(M_{meta}\) 的資料集。與此相對, \(P_1 \sim P_{5}\) 即為 meta-test classes, \(P_1 \sim P_{5}\) 包含的 100 個樣本即為 \(D_{meta-test}\),作為訓練和測試 $M_{fine-tune} $ 的資料集。
我們的實驗設定為5-way 5-shot,因此在 \(M_{meta}\) 階段,我們從 \(C_1 \sim C_{10}\) 中隨機選取 5 個類別,每個類別再隨機選取 20 個已標註樣本,組成一個 Task \(\text{T}\),其中的 5 個已標註樣本稱為 \(\text{T}\) 的 support set,另外 15 個樣本稱為 \(\text{T}\) 的 query set。這個 Task \(\text{T}\) 相當於普通深度學習模型訓練過程的一個數據,因此我們需要反覆在訓練資料分佈中抽取若干個 \(\text{T}\) 組成 batch ,才能使用隨機梯度下降 SGD。
MAML 演算法流程
這裡內容主要來自知乎:https://zhuanlan.zhihu.com/p/57864886
以上是 MAML 預訓練階段的演算法,目的是得到模型 \(M_{meta}\) 。下面是逐行分析:
首先是前兩個 Require。第一個 Require 指的是 \(D_{meta-train}\) 中 task 的分佈,我們可以反覆隨機抽取 task,形成一個由若干個 \(\text{T}\) 組成的 task 池,作為 MAML 的訓練集。第二個 Require 就是學習率,MAML 是基於二重梯度的,每次迭代包含兩次引數更新的過程,所以有兩個學習率可以調整。
步驟1:隨機初始化模型引數;
步驟2:是一個迴圈,可以理解為一輪迭代過程或一個 Epoch,當然,預訓練過程也可以有多個 Epoch,相當於設定 Epoch;
步驟3:隨機對若干個(e.g., 4 個)task 進行取樣,形成一個 batch;
步驟4 \(\sim\) 步驟7:第一次梯度更新過程。注意這裡我們可以理解為copy了一個原模型,計算出新的引數,用在第二輪梯度的計算過程中。 我們說過,MAML是gradient by gradient的,有兩次梯度更新的過程。步驟4~7中,利用batch中的每一個task,我們分別對模型的引數進行更新(4個task即更新4次)。 注意這個過程在演算法中是可以反覆執行多次的,但是虛擬碼沒有體現這一層迴圈 。
步驟5: 利用 batch 中的某一個 task 中的 support set( 在 N-way K-shot 的設定下,這裡的support set 應該有 NK 個 ),計算每個引數的梯度。 注意: 這裡的loss計算方法,在迴歸問題中,就是MSE;在分類問題中,就是cross-entropy。
步驟6:第一次梯度的更新。
步驟4 \(\sim\) 步驟7: 結束後,MAML完成了第一次梯度更新。接下來我們要做的,是根據第一次梯度更新得到的引數,通過gradient by gradient,計算第二次梯度更新。第二次梯度更新時計算出的梯度,直接通過SGD作用於原模型上,也就是我們的模型真正用於更新其引數的梯度。
步驟8:這裡對應第二次梯度更新的過程。這裡的loss計算方法,大致與步驟5相同,但是不同點有兩處:第一處是我們不再分別利用每個task的loss更新梯度,而是像常見的模型訓練過程一樣,計算一個batch的loss總和,對梯度進行隨機梯度下降SGD;第一處是這裡參與計算的樣本,是task中的 query set,在我們的例子中,即5-way*15=75個樣本,目的是增強模型在task上的泛化能力,避免過擬合 support set。步驟8結束後,模型結束在該batch中的訓練,開始回到步驟3,繼續取樣下一個batch。
以上便是 MAML 預訓練得到 \(M_{meta}\) 的全部過程。
接下來,在面對新的 task 時,我們將在 \(M_{meta}\) 的基礎上,精調(fine-tune)得到 \(M_{fine-tune}\) 。
精調過程於預訓練過程大致相同,不同之處有以下幾點:
- 步驟 1 中,fine-tune 不用再隨機初始化引數,而是利用訓練好的 \(M_{meta}\) 初始化引數;
- 步驟 3 中,fine-tune只需要抽取一個task進行學習,自然也不用形成batch。fine-tune利用這個task的support set訓練模型,利用query set測試模型。 實際操作中,我們會在 \(D_{meta-test}\) 上隨機抽取多個 task(e.g., 500 個),分別微調模型 \(M_{meta}\),並對最後測試結果進行平均,避免極端情況;
- fine-tune 沒有步驟 8, 因為task的query set是用來測試模型的,標籤對模型是未知的。因此fine-tune過程沒有第二次梯度更新,而是直接利用第一次梯度計算的結果更新引數。
References:
[1] 深度學習小站——知乎
[2] From zero to research — An introduction to Meta-learning
[3] BAIR——MAML作者的 bolg
[4] Meta-Learning: Learning to Learn Fast