CAN:藉助先驗分佈提升分類效能的簡單後處理技巧
顧名思義,本文將會介紹一種用於分類問題的後處理技巧——CAN(Classification with Alternating Normalization),出自論文《When in Doubt: Improving Classification Performance with Alternating Normalization》。經過筆者的實測,CAN確實多數情況下能提升多分類問題的效果,而且幾乎沒有增加預測成本,因為它僅僅是對預測結果的簡單重新歸一化操作。
有趣的是,其實CAN的思想是非常樸素的,樸素到每個人在生活中都應該用過同樣的思想。然而,CAN的論文卻沒有很好地說清楚這個思想,只是純粹形式化地介紹和實驗這個方法。本文的分享中,將會盡量將演算法思想介紹清楚。
思想例子#
假設有一個二分類問題,模型對於輸入aa給出的預測結果是p(a)=[0.05,0.95]p(a)=[0.05,0.95],那麼我們就可以給出預測類別為11;接下來,對於輸入bb,模型給出的預測結果是p(b)=[0.5,0.5]p(b)=[0.5,0.5],這時候處於最不確定的狀態,我們也不知道輸出哪個類別好。
但是,假如我告訴你:1、類別必然是0或1其中之一;2、兩個類別的出現概率各為0.5。在這兩點先驗資訊之下,由於前一個樣本預測結果為1,那麼基於樸素的均勻思想,我們是否更傾向於將後一個樣本預測為0,以得到一個滿足第二點先驗的預測結果?
這樣的例子還有很多,比如做10道選擇題,前9道你都比較有信心,第10題完全不會只能瞎蒙,然後你一看發現前9題選A、B、C的都有就是沒有一個選D的,那麼第10題在蒙的時候你會不會更傾向於選D?
這些簡單例子的背後,有著跟CAN同樣的思想,它其實就是用先驗分佈來校正低置信度的預測結果,使得新的預測結果的分佈更接近先驗分佈。
不確定性#
準確來說,CAN是針對低置信度預測結果的後處理手段,所以我們首先要有一個衡量預測結果不確定性的指標。常見的度量是“熵”,對於p=[p1,p2,⋯,pm]p=[p1,p2,⋯,pm],定義為:
然而,雖然熵是一個常見選擇,但其實它得出的結果並不總是符合我們的直觀理解。比如對於p(a)=[0.5,0.25,0.25]p(a)=[0.5,0.25,0.25]和p(b)=[0.5,0.5,0]p(b)=[0.5,0.5,0],直接套用公式得到
一個簡單的修正是隻用前top-kk個概率值來算熵,不失一般性,假設p1,p2,⋯,pkp1,p2,⋯,pk是概率最高的kk個值,那麼
其中p~i=pi/∑i=1kpip~i=pi/∑i=1kpi。為了得到一個0~1範圍內的結果,我們取Htop-k(p)/logkHtop-k(p)/logk為最終的不確定性指標。
演算法步驟#
現在假設我們有NN個樣本需要預測類別,模型直接的預測結果是NN個概率分佈p(1),p(2),⋯,p(N)p(1),p(2),⋯,p(N),假設測試樣本和訓練樣本是同分布的,那麼完美的預測結果應該有:
其中p~p~是類別的先驗分佈,我們可以直接從訓練集估計。也就是說,全體預測結果應該跟先驗分佈是一致的,但受限於模型效能等原因,實際的預測結果可能明顯偏離上式,這時候我們就可以人為修正這部分。
具體來說,我們選定一個閾值ττ,將指標小於ττ的預測結果視為高置信度的,而大於等於ττ的則是低置信度的,不失一般性,我們假設前nn個結果p(1),p(2),⋯,p(n)p(1),p(2),⋯,p(n)屬於高置信度的,而剩下的N−nN−n個屬於低置信度的。我們認為高置信度部分是更加可靠的,所以它們不用修正,並且可以用它們來作為“標準參考系”來修正低置信度部分。
具體來說,對於∀j∈{n+1,n+2,⋯,N}∀j∈{n+1,n+2,⋯,N},我們將p(j)p(j)與高置信度的p(1),p(2),⋯,p(n)p(1),p(2),⋯,p(n)一起,執行一次“行間”標準化:
這裡的k∈{1,2,⋯,n}∪{j}k∈{1,2,⋯,n}∪{j},其中乘除法都是element-wise的。不難發現,這個標準化的目的是使得所有新的p(k)p(k)的平均向量等於先驗分佈p~p~,也就是促使式(3)(3)的成立。然而,這樣標準化之後,每個p(k)p(k)就未必滿足歸一化了,所以我們還要執行一次“行內”標準化:
但這樣一來,式(3)(3)可能又不成立了。所以理論上我們可以交替迭代執行這兩步,直到結果收斂(不過實驗結果顯示一般情況下一次的效果是最好的)。最後,我們只保留最新的p(j)p(j)作為原來第jj個樣本的預測結果,其餘的p(k)p(k)均棄之不用。
注意,這個過程需要我們遍歷每個低置信度結果j∈{n+1,n+2,⋯,N}j∈{n+1,n+2,⋯,N}執行,也就是說是逐個樣本進行修正,而不是一次性修正的,每個p(j)p(j)都藉助原始的高置信度結果p(1),p(2),⋯,p(n)p(1),p(2),⋯,p(n)組合來按照上述步驟迭代,雖然迭代過程中對應的p(1),p(2),⋯,p(n)p(1),p(2),⋯,p(n)都會隨之更新,但那只是臨時結果,最後都是棄之不用的,每次修正都是用原始的p(1),p(2),⋯,p(n)p(1),p(2),⋯,p(n)。
參考實現#
這是筆者給出的參考實現程式碼:
# 預測結果,計算修正前準確率
y_pred = model.predict(
valid_generator.fortest(), steps=len(valid_generator), verbose=True
)
y_true = np.array([d[1] for d in valid_data])
acc_original = np.mean([y_pred.argmax(1) == y_true])
print('original acc: %s' % acc_original)
# 評價每個預測結果的不確定性
k = 3
y_pred_topk = np.sort(y_pred, axis=1)[:, -k:]
y_pred_topk /= y_pred_topk.sum(axis=1, keepdims=True)
y_pred_uncertainty = -(y_pred_topk * np.log(y_pred_topk)).sum(1) / np.log(k)
# 選擇閾值,劃分高、低置信度兩部分
threshold = 0.9
y_pred_confident = y_pred[y_pred_uncertainty < threshold]
y_pred_unconfident = y_pred[y_pred_uncertainty >= threshold]
y_true_confident = y_true[y_pred_uncertainty < threshold]
y_true_unconfident = y_true[y_pred_uncertainty >= threshold]
# 顯示兩部分各自的準確率
# 一般而言,高置信度集準確率會遠高於低置信度的
acc_confident = (y_pred_confident.argmax(1) == y_true_confident).mean()
acc_unconfident = (y_pred_unconfident.argmax(1) == y_true_unconfident).mean()
print('confident acc: %s' % acc_confident)
print('unconfident acc: %s' % acc_unconfident)
# 從訓練集統計先驗分佈
prior = np.zeros(num_classes)
for d in train_data:
prior[d[1]] += 1.
prior /= prior.sum()
# 逐個修改低置信度樣本,並重新評價準確率
right, alpha, iters = 0, 1, 1
for i, y in enumerate(y_pred_unconfident):
Y = np.concatenate([y_pred_confident, y[None]], axis=0)
for j in range(iters):
Y = Y**alpha
Y /= Y.mean(axis=0, keepdims=True)
Y *= prior[None]
Y /= Y.sum(axis=1, keepdims=True)
y = Y[-1]
if y.argmax() == y_true_unconfident[i]:
right += 1
# 輸出修正後的準確率
acc_final = (acc_confident * len(y_pred_confident) + right) / len(y_pred)
print('new unconfident acc: %s' % (right / (i +