1. 程式人生 > >Matching Networks for One Shot Learning論文分析

Matching Networks for One Shot Learning論文分析

Matching Networks for One Shot Learning

Abstract

研究領域: One Shot Learning(小樣本學習)從少量樣本中快速學習,是傳統監督學習和Deep Learning無法解決的問題,該研究領域被稱為小樣本學習

創新:
以下兩種方法結合:

  • metric learning目前,小樣本學習的主流方法
  • external memories以前小樣本學習的主流方法

資料集: Omniglot & ImageNet

1. Introduction

  1. 人類可以從少量樣本中學習新的概念。比如:一個小朋友看到鄰居家的新玩具一次,下次跟媽媽去商場的時候馬上就能從貨架上認出它來。
  2. 現在Deep Learning仍然需要大資料的驅動。
  3. 一些non-parametric model可以快速學習新樣本,比如KNN。本文要融合parametric model(即DL)和non-parametric model。DL中的樣本是用完即棄的,而KNN中的樣本會被儲存。
  4. 本文還為在Omniglot & ImageNet上的One Shot Learning實驗設定了benchmark。

2. Model

2.1 Model Architecture

Matching Net architecture

  1. 在網路上加external memories。
  2. external memories有很多種。在seq2seq中,external memories用於對P
    (BA),whereAandBcanbeasequenceP(B|A),\ where\ A\ and\ B\ can\ be\ a\ sequence
    的建模。在本文的Matching Net中,也用這種方式,只不過這裡的A,BA,\ B是一個set。如上圖所示,網路的輸入是有多個圖片組成的set。
  3. 數學建模部分。這裡比較複雜,我會講的詳細一點。

從上圖可以看到,左邊4個圖片形成一組,稱為support set;右下1個單身狗,稱為test example。全部5個圖片稱為1個task。

該模型用函式可表示為predic

tion=f(support_set,test_example)prediction = f(support\_set,\ test\_example),即模型有兩個輸入。該模型用概率可表示為P(y^x^,S)P(\hat y|\hat x, S), 其中S={(xi,yi)}i=1kS = \{(x_i, y_i)\}_{i=1}^k,k表示support set中樣本的個數。上圖support set有4個圖片,k=4。

Matching Net作者把該模型表示為:
y^=i=1ka(x^,xi)yi\hat y = \sum_{i=1}^k a(\hat x, x_i) y_i

預測值y^\hat y被看做是support set中樣本的labels的線性組合,組合的權重是test example和support set中1個樣本的關係——a(x^,xi)a(\hat x, x_i)

  • a(x^,xi)a(\hat x, x_i)作為一個核函式,則該模型可近似為:Deep Learning做嵌入層,KDE做分類層
  • a(x^,xi)a(\hat x, x_i)作為一個01函式,則該模型可金思維:Deep Learning做嵌入層,KNN做分類層

2.1.1 The Attention Kernel

本文賦予a(x^,xi)a(\hat x, x_i)新的形式——將它看做attention kernel。此時,模型的預測結果就是support set中attention最多的圖片的label。

常見的attention kernel是cosine距離上的softmax:
a(x^,xi)=ec(f(x^),g(xi))j=1kec(f(x^),g(xj))a(\hat x, x_i) = \frac {e^{c(f(\hat x), g(x_i))}}{\sum_{j=1}^k e^{c(f(\hat x), g(x_j))}},其中f,gf, g是兩個嵌入函式(可由神經網路實現,如:VGG or Inception)。

2.1.2 Full Context Embeddings

嵌入向量emb_xi=g(xi)g(xi,S)emb\_x_i = g(x_i) \leftarrow g(x_i, S),嵌入函式的輸出同時由對應的xix_i和整個support set有關。support set是每次隨機選取的,嵌入函式同時考慮support set和xix_i可以消除隨機選擇造成的差異性。類似機器翻譯中word和context的關係,SS可以看做是xix_i的context,所以本文在嵌入函式中用到了LSTM。

對text example的嵌入函式為ff:
f(x^,S)=attLSTM(f(x^),g(S),K)f(\hat x, S) = \textbf{attLSTM}(f'(\hat x), g(S), K),其中f(x^)f'(\hat x)是CNN嵌入層的輸出,可以是VGG或Inception,g(S)g(S)是support set中樣本的嵌入函式輸出,K是LSTM層的timesteps,等於support set的圖片個數。

詳解full context embedding:

The Fully Conditional Embedding f

h^k,ck=LSTM(f(x^),[hk1,rk1],ck1)\hat h_k, c_k = LSTM(f'(\hat x), [h_{k-1}, r_{k-1}], c_{k-1})

hk=h^k+f(x^)h_k = \hat h_k + f'(\hat x)

rk1=i=1Sa(hk1,g(xi))g(xi)r_{k-1} = \sum_{i=1}^{|S|} a(h_{k-1}, g(x_i))g(x_i)

a(hk1,g(xi))=softmax(hk1Tg(xi))a(h_{k-1}, g(x_i)) = softmax(h_{k-1}^Tg(x_i))

The Fully Conditional Embedding g

support set中的xix_i在經過多層卷積網路後,在經過一層bidirectional LSTM。

2. 2Training Strategy

  • 一個batch包括多個task;
  • 一個task包括一個support set和一個test example;
  • 一個support set包括多個sample(image & label);
  • support set中有且只有一個樣本與test example同類。

Related Work

Memory Augumented Neural Networks attention機制
Metric Learning 比較學習