1. 程式人生 > >triplet loss 原理以及梯度推導

triplet loss 原理以及梯度推導

【前言】  最近,learning to rank 的思想逐漸被應用到很多領域,比如google用來做人臉識別(faceNet),微軟Jingdong Wang 用來做 person-reid 等等。learning to rank中其中重要的一個步驟就是找到一個好的similarity function,而triplet loss是用的非常廣泛的一種。

【理解triplet】

è¿éåå¾çæè¿°

如上圖所示,triplet是一個三元組,這個三元組是這樣構成的:從訓練資料集中隨機選一個樣本,該樣本稱為Anchor,然後再隨機選取一個和Anchor (記為x_a)屬於同一類的樣本和不同類的樣本,這兩個樣本對應的稱為Positive (記為x_p)和Negative (記為x_n),由此構成一個(Anchor,Positive,Negative)三元組。

【理解triplet loss】 

有了上面的triplet的概念, triplet loss就好理解了。針對三元組中的每個元素(樣本),訓練一個引數共享或者不共享的網路,得到三個元素的特徵表達,分別記為:è¿éåå¾çæè¿°triplet loss的目的就是通過學習,讓x_a和x_p特徵表達之間的距離儘可能小,而x_a和x_n的特徵表達之間的距離儘可能大,並且要讓x_a與x_n之間的距離和x_a與x_p之間的距離之間有一個最小的間隔:α 。公式化的表示就是: 

對應的目標函式也就很清楚了: 

è¿éåå¾çæè¿°   這裡距離用歐式距離度量,+表示[ ]內的值大於零的時候,取該值為損失,小於零的時候,損失為零。 

由目標函式可以看出:

當x_a與x_n之間的距離 < x_a與x_p之間的距離加α時,[ ]內的值大於零,就會產生損失。 當x_a與x_n之間的距離 >= x_a與x_p之間的距離加α時,損失為零。

【triplet loss 梯度推導】 

上述目標函式記為L。則當第i個triplet損失大於零的時候,僅就上述公式而言,有: 

è¿éåå¾çæè¿° 【演算法實現時候的提示】  可以看到,對x_p和x_n特徵表達的梯度剛好利用了求損失時候的中間結果,給的啟示就是,如果在CNN中實現 triplet loss layer, 如果能夠在前向傳播中儲存著兩個中間結果,反向傳播的時候就能避免重複計算。這僅僅是演算法實現時候的一個Trick。