1. 程式人生 > 實用技巧 >Self-paced Co-training 程式碼解析

Self-paced Co-training 程式碼解析

Self-paced Co-training 程式碼解析

曹世磊 學生

前言

Self-paced Co-training發表在icml2017,是馬凡師兄還在西安交通大學的時候跟著孟德宇老師做的。本文主要從code(見文末引用)角度對這篇文章進行解析,但考慮到講code不能脫離文章,所以也會對文章進行簡要回顧。

文章簡要回顧

  • Co-training理解。最早由Blum he Mitchell在1998年提出(見Combining labeled and unlabeled data with co-training),是一種常見的半監督學習方法。參考
    [2]
    的說法,co-training核心就是利用少量已標記樣本,通過兩個(或多個)model/view去學習。利用多個model/view的(中間)訓練結果對未標記樣本進行交替預測,挑選出most confidently的樣本加入到已標記樣本陣營,然後再重複上述過程,直到收斂條件達到為止。co-training目前主要存在兩種方法:single-view 和 multi-view。最開始提出的是multi-view,就是對特徵進行拆分,使用相同的模型,來保證模型間的差異性。後來論證了single-view方法,也就是採用不同種類的模型,但是採用全部特徵,也是可以的。基於後一種方法,好多開始做整合方法,採用boosting
    方式,加入更多分類器,當然也是可以同時做特徵的取樣。
  • 本文的insight。本文的insight還是很震撼的。傳統的cotraining方法通過交替訓練多個model/view生產pseudo-label,從而擴充訓練集。這裡就存在一個問題:模型在迭代前期,生產的pseudo-label很多都是錯誤的,這些錯誤的pseudo-label在後期的模型訓練中不停地被利用,從而干擾了模型的訓練。那能不能解決這一問題呢?這篇文章就是朝著這個方向努力的。他可以做到,動態的選擇sample,動態選擇具有draw with replacement的特性,從而解決了以前一次選擇,終身就不能逃離的困境。給出答案前,需要引入
    孟德宇
    老師另一重要的工作--self-paced learning。這裡簡單對Self-paced Learning進行介紹,想深入研究的可以看下孟老師的主頁。我們在求解機器學習模型時,通常有一個目標函式,一般就是模型對標註樣本的損失函式。但是我們知道,不可能存在完美的標註樣本,在一些情況下,這些標註的樣本可能本身就是錯誤標註的。這時就需要把這些樣本在損失函式的計算中去掉(當然Self-paced Learning的另一種解釋就是模型訓練遵循easy to complex的形式)。怎麼做到呢?這時就要發揮數學的作用了。公式(1)給出Self-paced Learning通用形式。我們把重點放在第一項上,觀察可以發現,我們控制的取值來控制當前樣本對對目標函式的貢獻量。極端情況全為1的時候,就是傳統的損失函式。某些情況為0的時候,這部分樣本就不會納入計算,全為0的話,目標函式就無意義了(當然一般不存在這種情況)。怎麼去控制的取值呢,就是第二項的作用了。也就是說,我們在計算Loss的時候,可以引入一個額外的項,這個額外的項與損失函式無關,但可以起到控制樣本貢獻的作用。關於第二項的具體形式以及各種變種,可以參考文章[3]。
(1)
這裡多說幾句,的取值與以及損失函式有關。一般稱為age,常常作為損失函式的某種上界來控制的取值。控制損失函式的大小來選擇sample的思想其實很常見。大家可以想到object detection領域常用的hard mining,以及最新Kaiming組提到的Focal loss。這裡有必要說下Focal loss裡得到的一個非常重要的結論。影響物體檢測的關鍵因素是負樣本數量太大,這些負樣本佔總loss比例大,而且多是容易分類的。這些容易分類的loss其實用處不大,但完全忽略會丟失資訊,所以引入Focal Loss。對比下這些想法,不知道能不能促生出self-paced Learning新的發現。扯得有點遠了...
  • 主要工作論述。為了避免co-training在訓練過程中引入的錯誤的pseudo-label資料,本文構建了公式(2)的目標函式。第一項就是傳統的損失函式,第二項是正則項。第三項是本文的重點。通過Self-paced Learning對當前生產pseudo-label的樣本進行選擇。第四項是上面說的用來控制的。最後一項也是正則項,用來鼓勵view之間的判斷的一致性(an unlabeled sample is likely to be labeled correctly or wrongly simultaneously for both views)。求解是常用的Alternative Optimization演算法。
(2)

文章就講到這裡吧,大家不懂的地方可以看看文章,或者留言一塊兒交流。

程式碼解析

這裡說下對程式碼的理解吧。程式碼改自Facebook Tong Xiao開源的open-reid。整個框架整潔、幹練,是個不錯的學習資源。框架的詳細資訊可以參考Sphinx生成的線上文件。由於需要deep learning學習feature,框架需要依賴於PyTorch(version >= 0.2.0)的支援。此外,還呼叫了umass大學 ALL lab開源metric-learn。這裡我們主要以馬凡師兄的修改版進行講解

  • 框架整體架構。
$ tree -d -L 2
.
├── docs
│   ├── examples
│   ├── figures
│   ├── notes
│   └── _static
├── examples
│   └── data
├── reid
│   ├── datasets
│   ├── evaluation_metrics
│   ├── feature_extraction
│   ├── loss
│   ├── metric_learning
│   ├── models
│   └── utils
└── test
    ├── datasets
    ├── evaluation_metrics
    ├── feature_extraction
    ├── loss
    ├── models
    └── utils

框架的主要內容在reid folder下,examples folder下給出了一些框架的實際用例(spaco.py就是論文的實現程式碼)。docs和test下分別是Sphinx文件和test用例。

下面我們從主函式入手,對論文的實現程式碼給出說明。

  • 程式碼詳細解析(examples/spaco.py為入口)
# 82-99行
if __name__ == '__main__': => main(parser.parse_args())
# 主要呼叫argparse軟體包,使程式可以接收命令列傳遞進來的引數。
# 注意:由於--batch-size需要專門配置,所以在此實現中這項引數並沒有應用。

p.s. 翻開一看,嚇了一跳,上次編輯已經11天前了。當然期間發生了一些事情,還參加了MLA2017。今天先寫一部分,明天早上必須完成。立帖為證!!!

# 71-74行
def main(args): => dataset = datasets.create(args.dataset, dataset_dir)
# 這裡我們主要關注74行。
# 由於re-id任務有很多資料集,但它們都以不同的格式存在,所以需要我們把它統一起來。
# 以這個為目的,作者實現了統一的資料層程式碼(見reid/datasets/)。為了能通過引數
# 值來接觸到不同的資料集,作者定義了入口檔案reid/datasets/__init__.py,從而
# 統一了資料的獲取形式,方式是建立__factory字典,形成string到class的對映。我
# 們重點關注Market1501_STD資料集,這個資料集其實是馬凡師兄對照原作Market1501
# 自己加的。為什麼不用原作者的,馬凡師兄的解釋是Market1501的query樣本和一般論文
# 的標準測試不一樣,其實就是修改了download函式病過載了load函式。

# reid/datasets/Market1501_STD.py 
# 進入Market1501_STD後,我們看建構函式。建構函式主要是執行download和load函式。
# 首先是download函式,由於有些資料集需要複雜的url parse等操作,所以作者給出url,
# 希望你自己下載好了再來執行(當然一些簡單的可以在程式裡download)。91-118主要是
# 判斷檔案完整和解壓縮。120-151行主要是對檔案進行統一命名,並且儲存成統一的檔案結
# 構。153-163主要儲存些資料集的原資訊,包括meta.json:資料集名稱、相機個數、
# single/multiple(一般涉及tracking) shot。splits.json:trainval/test
# 的檔名稱(路徑)、query以及gallery的檔名稱(路徑)。
# 對比Market1501.py可以看出,馬凡師兄修改了splits.json儲存的內容,Market1501.py
# 儲存的是圖片的pid,而馬凡師兄儲存的是檔名,更詳細了些。這個具體的作用,我們後面對比
# 看下。
# 然後是load函式。由於download函式有變化,load函式也需要改變,所以師兄過載了load函式。
# 43-47匯入上面儲存的splits.json。重要的地方就是傳進來的引數num_val,他其實是train/val
# split的一個分界點。由於本文關注的半監督學習,所以師兄沒有實現train,而是在後續code
#(spaco.py 26行)實現了train/untrain的split。untrain就是co-training方法的
# unlabeled的資料集倉庫。val和trainval都根據pid從identities全集裡取(val_pids是
# trainval_pids的後num_val個數據)。接下來67-74就是為什麼師兄重寫Market1501.py的
# 原因了。原因在於identities庫是以pid索引的,而有些test的pid和query pid相等,導致取
# query資料集時,會誤取一些test資料集。

# 以上內容其實把整個dataset包都講了,其他資料集也是類似的。

# 75-80行
model_names = [args.arch1, args.arch2] => args.gamma,args.train_ratio)
# 讀取model name-方便後面根據名字得到model 和save path-方便後面存log

接下來就可以進入code的核心spaco函數了。

# 12-28行
def spaco(...) => num_classes = data.num_trainval_ids
# 首先解釋下幾個引數吧。iter_step就是我們co-training交替迭代的次數。gamma比較重要,
# 後面講。train_ratio就是train/untrain中train的比例,以適應半監督學習。
# 重點講講26行split_dataset的實現。實現比較特殊,劃分train/untrain時,讓每個人的影象
# 在train/untrain中都有出現。具體直接看code就可以。

p.s.後面就是文章的重點實現了,今晚先到這裡吧。

# 33-40行
add_ratio = 0.5 => pred_y = np.argmax(sum(pred_probs), axis=1)
# 這幾行其實是第一次co-training迭代,目的是對兩個model(比如resnet50,densenet121)
# 進行訓練。add_ratio引數和pred_probs的意義後面講,直接看37行。這裡遍歷num_view,
# 分別對網路執行train_predict。咱們看看train_predict,這就需要進入models軟體包裡。

# reid/models/model_utils.py
# 進入train_predict函式後,首先執行131行的train。咱們需要進入函式train去詳細瞭解下。
# 首先執行get_model_by_name,這裡其實和dataset包一致,也是為了為多個model提供統一的
# 入口,在reid/models/__init__.py定義了__factory字典,達到由model_name到model的
# 對映。可以看到,在reid/models裡原作者實現了一般的model介面。

# 下面我們以reid/models/densenet.py為例,講講作者的網路搭建技巧。32行可以看到,
# 作者又定義了一個__factory字典,達到由model depth得到model的目的。有趣的是,這些
# 常見的model已經在torchvision裡包含了(讚揚一個pytorch)。由於原作者想把densenet
# 等網路當做base architecture。所以這裡重新實現了forword函式。對比forword函式以及
# 37-62行,可以發現,這裡加了一個全連線層,支援固定個數的num_features輸出,並對映到
# 類的個數空間裡。

# 繼續講解reid/models/model_utils.py的train函式107-110是得到訓練網路的資料層和
# 引數,這時,我們會發現spaco.py傳入的--batch-size其實沒有用。重點講109行。這裡就
# 是對資料進行一些常見的變換了,目的就是達到資料增廣的效果。進入
# reid/utils/data/data_process.py 直接看函式get_dataloader。這裡重點看26行,
# 傳入DataLoader的第一個引數Preprocessor,這是一個類,這個類比較特殊的地方就是必
# 須實現__getitem__方法,以便於DataLoader進行資料獲取。想深度瞭解的,其實看看pytorch
# 實現的DataLoader,以便於為自己後續開發積累。回到train函式,接下來就是訓練網路了
# (函式train_model),這個實現比較簡單,大家自己就能看懂。
# 到這裡reid/models/model_utils.py裡的train函式就講完了

# 回到train_predict,發現緊接著就是根據上面訓練好的網路進行預測。重點看
# reid/models/model_utils.pyde 133行。這裡傳入untrain data,用訓練好的網路對
# untrain data進行預測,以便服務於co-training。
# 下面進入predict_prob進行詳細說明。首先是get_feature函式,這個主要涉及untrain data
# 的資料封裝(創給dataloader),然後把網路的預測拿到。具體怎麼拿到的,其實可以看
# torch/nn/modules/module.py裡的__call__函式。123-126行就是用softmax把這些
# feature對映到概率空間。最終得到untrain data的概率預測。

# 至此train_predict全部函式講完。回到主函式的39行

這裡需要把實現與演算法對上,所以貼上演算法圖

Update vkUpdate wUpdate ykAlgorithm

上面我們主要講了1-5步,當我們進入第八步迭代時,檢視公式(4)會發現需要通過控制loss對進行更新。那我們實現上是不是必須先計算loss?答案是否定的,由於公式4的目的在於控制loss大的樣本不被選中,loss大對應預測概率低,所以這裡不需要計算loss,而是簡單的對預測概率進行操作就可以了。有了這樣的認識,我們繼續進行code講解就變得容易了。

# 39行
add_ids.append(dp.sel_idx(pred_probs[view], train_data, add_ratio))
# 我們直接跳進dp.sel_idx裡。
# 這裡我們會發現add_ratio是每次從untrain data新增進train data的比例。這個函式實現
# 也很容易理解。就是根據上面得到的pred_prob,對預測概率進行排序,然後取前面ratio個數據
# 以便加進來,為另一個view下次訓練服務。

接下來看演算法的第9步,這一步的目的是選擇最終被拿來當做訓練資料的untrain data。45-47的目的是在另一個view對untrain data選擇的基礎上選擇出符合本view model預測效果的資料。方式也是用dp.sel_idx對概率進行排序選擇。這裡需要注意的是,為了增加最終加入train data的untrain data個數,先對預測概率加了一個gamma,引數的另一個意義就是對照演算法的age。接著50-53行對train data進行更新,並繼續co-training訓練。56-63行為下次迭代選擇樣本服務。

到這裡,co-training的演算法基本上講完了。接下來結合code講講re-id的評價標準。由於最近才接觸re-id,對其評價標準極其不瞭解,我就查查資料,慢慢講吧。

評價指標

從code層面講,評價指標引入了metric learning。具體來說就是把算softmax以得到概率之前的logit當做feature,他們的真實label當做監督資訊,來訓練一個metric。這個metric可以是metric learning領域的任何metric,前文也提到了,作者主要引入了metric-learn開源包。至於為什麼把logit當做feature,以及為什麼用euclidean distance(code中用的這個,這個就不用metric learning),這個我現在還沒想明白。為什麼沒想明白,是因為訓練時用的CrossEntropyLoss,沒有涉及兩個樣本的互動(具體code可見reid/trainers.py,上面沒詳細講這裡)。

撇開上面的疑問,我們繼續看評價指標。主要看reid/evaluators.py evaluate_all函式。

這個函式的輸入就是我們計算好的query和gallery的parewise的distmat。這裡主要是82行計算map(mean average precision)以及96行計算cmc(Cumulative Matching Characteristics ) score。

  • mAP(reid/evaluation_metrics/ranking.py 函式mean_ap):101行對distmat的第一維(從0計數)進行排序,得到每個query的距離排序。然後遍歷m個query,Filter out the same id and same camera(107-109),呼叫sklearn.metrics的average_precision_score進行ap計算,最後得到mean ap。我個人的理解average_precision刻畫了precision與recall的對應關係,一般recall越大,precision越低。理由在於recall代表ranking序列中正樣本的佔總正樣本的比例,如果選擇的樣本都是正樣本的話,也就是說前面的ranking都是正樣本,其實不太符合邏輯,此時對負樣本的預測效果就會很差,所以precision比較低。更是詳細的可以參考知乎mean average precision(MAP)在計算機視覺中是如何計算和應用的?
  • cmcreid/evaluation_metrics/ranking.py 函式cmc:cmc主要從ranking角度評價re-id model的效能。具體地,給定query,檢視相同的gallery在topk的ranking中是否出現。作者給出的文件解釋在這。實現角度還是針對distmat的第一維排序(從0計數),得到每個query的距離排序。由於re-id分single shot 和multi shot(tracking),所以評價方式不一樣。single shot的實現需要對每個gallery_id sample一個gallery去評判結果是否在ranking中出現。

Poster

文中的證明

這個有時間我來分析下吧。

[1] Fan Ma, Deyu Meng*, Qi Xie, Zina Li, Xuanyi Dong,Self-paced Cotraining, ICML, 2017. [supplementary material][code][github link]

[2]Co-training 初探快切入

[3] Lu Jiang, Deyu Meng, Teruko Mitamura, Alexander Hauptmann.Easy Samples First: Self-paced Reranking for Zero-Example Multimedia Search. ACM MM. 2014.Slides.

編輯於 2017-11-12 Large-scale Machine Learning

推薦閱讀

Self-Driving Database

2020.4.5更新:補充一些最新的研究趨勢 今年的ICDE2020專門安排了一個 Self-Managing Database SystemsLearned Index(也不一定就是替換index,別的 具有經驗性規則的模組也行)https://arx…

PRO Pentium

一文入門元學習(Meta-Learning)(附程式碼)

涼爽的安迪發表於深度瞎學

Multi-task Learning and Beyond: 過去,現在與未來

劉詩昆

轉載|Multi-task Learning and Beyond: 過去,現在與未來

hyp12...發表於AI Bo...

9 條評論

寫下你的評論...
  • david2018-05-14 寫的不錯啊,留個微信或者qq交流吧,我們電信學院的
    • gmgl2019-12-08 求後續!感覺程式碼還是看不太懂啊,剛入門[流淚]
    • Velproqcr回覆gmgl02-29

      層主,內個實驗程式碼你找到了嗎?按照樓主的提示,我的examples沒有spaco.py

    • gmgl回覆Velproqcr02-29 找到了,但是還是執行不了,其中有個函式已經被棄用[捂臉],我沒解決
  • 樹偽神05-08

    首先點贊。

    請問,步驟2中的w(represents the model parameter inside the decision function g)到底是指什麼?是指所net的所有權重嗎,好像在程式碼裡也沒看到是怎麼初始化的

  • 葛葛08-20

    請問在更新偽標籤y的時候,公式6中存在引數v,如果對於某個為0的v,會導致公式6直接為0,那這個樣本的偽標籤怎麼計算呢?