1. 程式人生 > 實用技巧 >自訓練和半監督學習介紹

自訓練和半監督學習介紹

作者|Doug Steen
編譯|VK
來源|Towards Data Science

當涉及到機器學習分類任務時,用於訓練演算法的資料越多越好。在監督學習中,這些資料必須根據目標類進行標記,否則,這些演算法將無法學習獨立變數和目標變數之間的關係。但是,在構建用於分類的大型標記資料集時,會出現兩個問題:

  1. 標記資料可能很耗時。假設我們有1000000張狗影象,我們想將它們輸入到分類演算法中,目的是預測每個影象是否包含波士頓狗。如果我們想將所有這些影象用於監督分類任務,我們需要一個人檢視每個影象並確定是否存在波士頓狗。
  2. 標記資料可能很昂貴。原因一:要想讓人費盡心思去搜100萬張狗狗照片,我們可能得掏錢。

那麼,這些未標記的資料可以用在分類演算法中嗎?

這就是半監督學習的用武之地。在半監督方法中,我們可以在少量的標記資料上訓練分類器,然後使用該分類器對未標記的資料進行預測。

由於這些預測可能比隨機猜測更好,未標記的資料預測可以作為“偽標籤”在隨後的分類器迭代中採用。雖然半監督學習有很多種風格,但這種特殊的技術稱為自訓練。

自訓練

在概念層面上,自訓練的工作原理如下:

步驟1:將標記的資料例項拆分為訓練集和測試集。然後,對標記的訓練資料訓練一個分類演算法。

步驟2:使用經過訓練的分類器來預測所有未標記資料例項的類標籤。在這些預測的類標籤中,正確率最高的被認為是“偽標籤”。

(第2步的幾個變化:a)所有預測的標籤可以同時作為“偽標籤”使用,而不考慮概率;或者b)“偽標籤”資料可以通過預測的置信度進行加權。)

步驟3:將“偽標記”資料與正確標記的訓練資料連線起來。在組合的“偽標記”和正確標記訓練資料上重新訓練分類器。

步驟4:使用經過訓練的分類器來預測已標記的測試資料例項的類標籤。使用你選擇的度量來評估分類器效能。

(可以重複步驟1到4,直到步驟2中的預測類標籤不再滿足特定的概率閾值,或者直到沒有更多未標記的資料保留。)

好的,明白了嗎?很好!讓我們通過一個例子解釋。

示例:使用自訓練改進分類器

為了演示自訓練,我使用Python和surgical_deepnet 資料集,可以在Kaggle上找到:https://www.kaggle.com/omnamahshivai/surgical-dataset-binary-classification

此資料集用於二分類,包含14.6k+手術的資料。這些屬性是bmi、年齡等各種測量值,而目標變數complexing則記錄患者是否因手術而出現併發症。顯然,能夠準確地預測患者是否會因手術而出現併發症,這對醫療保健和保險供應商都是最有利的。

匯入庫

對於本教程,我將匯入numpy、pandas和matplotlib。我還將使用sklearn中的LogisticRegression分類器,以及用於模型評估的f1_score和plot_confusion_matrix 函式

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.linear_model import LogisticRegression

from sklearn.metrics import f1_score
from sklearn.metrics import plot_confusion_matrix

載入資料

# 載入資料

df = pd.read_csv('surgical_deepnet.csv')
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 14635 entries, 0 to 14634
Data columns (total 25 columns):
bmi                    14635 non-null float64
Age                    14635 non-null float64
asa_status             14635 non-null int64
baseline_cancer        14635 non-null int64
baseline_charlson      14635 non-null int64
baseline_cvd           14635 non-null int64
baseline_dementia      14635 non-null int64
baseline_diabetes      14635 non-null int64
baseline_digestive     14635 non-null int64
baseline_osteoart      14635 non-null int64
baseline_psych         14635 non-null int64
baseline_pulmonary     14635 non-null int64
ahrq_ccs               14635 non-null int64
ccsComplicationRate    14635 non-null float64
ccsMort30Rate          14635 non-null float64
complication_rsi       14635 non-null float64
dow                    14635 non-null int64
gender                 14635 non-null int64
hour                   14635 non-null float64
month                  14635 non-null int64
moonphase              14635 non-null int64
mort30                 14635 non-null int64
mortality_rsi          14635 non-null float64
race                   14635 non-null int64
complication           14635 non-null int64
dtypes: float64(7), int64(18)
memory usage: 2.8 MB

資料集中的屬性都是數值型的,沒有缺失值。由於我這裡的重點不是資料清理,所以我將繼續對資料進行劃分。

資料劃分

為了測試自訓練的效果,我需要將資料分成三部分:訓練集、測試集和未標記集。我將按以下比例拆分資料:

  • 1% 訓練
  • 25% 測試
  • 74% 未標記

對於未標記集,我將簡單地放棄目標變數complexing,並假裝它從未存在過。

所以,在這個病例中,我們認為74%的手術病例沒有關於併發症的資訊。我這樣做是為了模擬這樣一個事實:在實際的分類問題中,可用的大部分資料可能沒有類標籤。然而,如果我們有一小部分資料的類標籤(在本例中為1%),那麼可以使用半監督學習技術從未標記的資料中得出結論。

下面,我隨機化資料,生成索引來劃分資料,然後建立測試、訓練和未標記的劃分。然後我檢查各個集的大小,確保一切都按計劃進行。

X_train dimensions: (146, 24)
y_train dimensions: (146,)

X_test dimensions: (3659, 24)
y_test dimensions: (3659,)

X_unlabeled dimensions: (10830, 24)

類分佈

多數類的樣本數((併發症))是少數類(併發症)的兩倍多。在這樣一個不平衡的類的情況下,我想準確度可能不是最佳的評估指標。

選擇F1分數作為分類指標來判斷分類器的有效性。F1分數對類別不平衡的影響比準確度更為穩健,當類別近似平衡時,這一點更為合適。F1得分計算如下:

其中precision是預測正例中正確預測的比例,recall是真實正例中正確預測的比例。

初始分類器(監督)

為了使半監督學習的結果更真實,我首先使用標記的訓練資料訓練一個簡單的Logistic迴歸分類器,並對測試資料集進行預測。

Train f1 Score: 0.5846153846153846
Test f1 Score: 0.5002908667830134

分類器的F1分數為0.5。混淆矩陣告訴我們,分類器可以很好地預測沒有併發症的手術,準確率為86%。然而,分類器更難正確識別有併發症的手術,準確率只有47%。

預測概率

對於自訓練演算法,我們需要知道Logistic迴歸分類器預測的概率。幸運的是,sklearn提供了.predict_proba()方法,它允許我們檢視屬於任一類的預測的概率。如下所示,在二元分類問題中,每個預測的總概率總和為1.0。

array([[0.93931367, 0.06068633],
       [0.2327203 , 0.7672797 ],
       [0.93931367, 0.06068633],
       ...,
       [0.61940353, 0.38059647],
       [0.41240068, 0.58759932],
       [0.24306008, 0.75693992]])

自訓練分類器(半監督)

既然我們知道了如何使用sklearn獲得預測概率,我們可以繼續編碼自訓練分類器。以下是簡要概述:

第1步:首先,在標記的訓練資料上訓練Logistic迴歸分類器。

第2步:接下來,使用分類器預測所有未標記資料的標籤,以及這些預測的概率。在這種情況下,我只對概率大於99%的預測採用“偽標籤”。

第3步:將“偽標記”資料與標記的訓練資料連線起來,並在連線的資料上重新訓練分類器。

第4步:使用訓練好的分類器對標記的測試資料進行預測,並對分類器進行評估。

重複步驟1到4,直到沒有更多的預測具有大於99%的概率,或者沒有未標記的資料保留。

下面的程式碼使用while迴圈在Python中實現這些步驟。

Iteration 0
Train f1: 0.5846153846153846
Test f1: 0.5002908667830134
Now predicting labels for unlabeled data...
42 high-probability predictions added to training data.
10788 unlabeled instances remaining.

Iteration 1
Train f1: 0.7627118644067796
Test f1: 0.5037463976945246
Now predicting labels for unlabeled data...
30 high-probability predictions added to training data.
10758 unlabeled instances remaining.

Iteration 2
Train f1: 0.8181818181818182
Test f1: 0.505431675242996
Now predicting labels for unlabeled data...
20 high-probability predictions added to training data.
10738 unlabeled instances remaining.

Iteration 3
Train f1: 0.847457627118644
Test f1: 0.5076835515082526
Now predicting labels for unlabeled data...
21 high-probability predictions added to training data.
10717 unlabeled instances remaining.

...
Iteration 44
Train f1: 0.9481216457960644
Test f1: 0.5259179265658748
Now predicting labels for unlabeled data...
0 high-probability predictions added to training data.
10079 unlabeled instances remaining.

自訓練演算法經過44次迭代,就不能以99%的概率預測更多的未標記例項了。即使一開始有10,830個未標記的例項,在自訓練之後仍然有10,079個例項未標記(並且未被分類器使用)。

經過44次迭代,F1的分數從0.50提高到0.525!雖然這只是一個小的增長,但看起來自訓練已經改善了分類器在測試資料集上的效能。上圖的頂部面板顯示,這種改進大部分發生在演算法的早期迭代中。同樣,底部面板顯示,新增到訓練資料中的大多數“偽標籤”都是在前20-30次迭代中出現的。

最後的混淆矩陣顯示有併發症的手術分類有所改善,但沒有併發症的手術分類略有下降。有了F1分數的提高,我認為這是一個可以接受的進步-可能更重要的是確定會導致併發症的手術病例(真正例),並且可能值得增加假正例率來達到這個結果。

警告語

所以你可能會想:用這麼多未標記的資料進行自訓練有風險嗎?答案當然是肯定的。請記住,儘管我們將“偽標記”資料與標記的訓練資料一起包含在內,但某些“偽標記”資料肯定會不正確。當足夠多的“偽標籤”不正確時,自訓練演算法會強化糟糕的分類決策,而分類器的效能實際上會變得更糟。

可以使用分類器在訓練期間沒有看到的測試集,或者使用“偽標籤”預測的概率閾值,可以減輕這種風險。

原文連結:https://towardsdatascience.com/a-gentle-introduction-to-self-training-and-semi-supervised-learning-ceee73178b38

歡迎關注磐創AI部落格站:
http://panchuang.net/

sklearn機器學習中文官方文件:
http://sklearn123.com/

歡迎關注磐創部落格資源彙總站:
http://docs.panchuang.net/