1. 程式人生 > 其它 >CAN:藉助先驗分佈提升分類效能的簡單後處理技巧

CAN:藉助先驗分佈提升分類效能的簡單後處理技巧

顧名思義,本文將會介紹一種用於分類問題的後處理技巧——CAN(Classification with Alternating Normalization),出自論文《When in Doubt: Improving Classification Performance with Alternating Normalization》。經過筆者的實測,CAN確實多數情況下能提升多分類問題的效果,而且幾乎沒有增加預測成本,因為它僅僅是對預測結果的簡單重新歸一化操作。

有趣的是,其實CAN的思想是非常樸素的,樸素到每個人在生活中都應該用過同樣的思想。然而,CAN的論文卻沒有很好地說清楚這個思想,只是純粹形式化地介紹和實驗這個方法。本文的分享中,將會盡量將演算法思想介紹清楚。

思想例子#

假設有一個二分類問題,模型對於輸入aa給出的預測結果是p(a)=[0.05,0.95]p(a)=[0.05,0.95],那麼我們就可以給出預測類別為11;接下來,對於輸入bb,模型給出的預測結果是p(b)=[0.5,0.5]p(b)=[0.5,0.5],這時候處於最不確定的狀態,我們也不知道輸出哪個類別好。

但是,假如我告訴你:1、類別必然是0或1其中之一;2、兩個類別的出現概率各為0.5。在這兩點先驗資訊之下,由於前一個樣本預測結果為1,那麼基於樸素的均勻思想,我們是否更傾向於將後一個樣本預測為0,以得到一個滿足第二點先驗的預測結果?

這樣的例子還有很多,比如做10道選擇題,前9道你都比較有信心,第10題完全不會只能瞎蒙,然後你一看發現前9題選A、B、C的都有就是沒有一個選D的,那麼第10題在蒙的時候你會不會更傾向於選D?

這些簡單例子的背後,有著跟CAN同樣的思想,它其實就是用先驗分佈來校正低置信度的預測結果,使得新的預測結果的分佈更接近先驗分佈。

不確定性#

準確來說,CAN是針對低置信度預測結果的後處理手段,所以我們首先要有一個衡量預測結果不確定性的指標。常見的度量是“熵”,對於p=[p1,p2,,pm]p=[p1,p2,⋯,pm],定義為:

H(p)=i=1mpilogpi(1)H(p)=−∑i=1mpilog⁡pi


然而,雖然熵是一個常見選擇,但其實它得出的結果並不總是符合我們的直觀理解。比如對於p(a)=[0.5,0.25,0.25]p(a)=[0.5,0.25,0.25]和p(b)=[0.5,0.5,0]p(b)=[0.5,0.5,0],直接套用公式得到

H(p(a))>H(p(b))H(p(a))>H(p(b)),但就我們的分類場景而言,顯然我們會認為p(b)p(b)比p(a)p(a)更不確定,所以直接用熵還不夠合理。

一個簡單的修正是隻用前top-kk個概率值來算熵,不失一般性,假設p1,p2,,pkp1,p2,⋯,pk是概率最高的kk個值,那麼

Htop-k(p)=i=1kp~ilogp~i(2)Htop-k(p)=−∑i=1kp~ilog⁡p~i


其中p~i=pi/i=1kpip~i=pi/∑i=1kpi。為了得到一個0~1範圍內的結果,我們取Htop-k(p)/logkHtop-k(p)/log⁡k為最終的不確定性指標。

演算法步驟#

現在假設我們有NN個樣本需要預測類別,模型直接的預測結果是NN個概率分佈p(1),p(2),,p(N)p(1),p(2),⋯,p(N),假設測試樣本和訓練樣本是同分布的,那麼完美的預測結果應該有:

1Ni=1Np(i)=p~(3)(3)1N∑i=1Np(i)=p~


其中p~p~是類別的先驗分佈,我們可以直接從訓練集估計。也就是說,全體預測結果應該跟先驗分佈是一致的,但受限於模型效能等原因,實際的預測結果可能明顯偏離上式,這時候我們就可以人為修正這部分。

具體來說,我們選定一個閾值ττ,將指標小於ττ的預測結果視為高置信度的,而大於等於ττ的則是低置信度的,不失一般性,我們假設前nn個結果p(1),p(2),,p(n)p(1),p(2),⋯,p(n)屬於高置信度的,而剩下的NnN−n個屬於低置信度的。我們認為高置信度部分是更加可靠的,所以它們不用修正,並且可以用它們來作為“標準參考系”來修正低置信度部分。

具體來說,對於j{n+1,n+2,,N}∀j∈{n+1,n+2,⋯,N},我們將p(j)p(j)與高置信度的p(1),p(2),,p(n)p(1),p(2),⋯,p(n)一起,執行一次“行間”標準化

p(k)p(k)/p¯×p~,p¯=1n+1(p(j)+i=1np(i))(4)p(k)←p(k)/p¯×p~,p¯=1n+1(p(j)+∑i=1np(i))


這裡的k{1,2,,n}{j}k∈{1,2,⋯,n}∪{j},其中乘除法都是element-wise的。不難發現,這個標準化的目的是使得所有新的p(k)p(k)的平均向量等於先驗分佈p~p~,也就是促使式(3)(3)的成立。然而,這樣標準化之後,每個p(k)p(k)就未必滿足歸一化了,所以我們還要執行一次“行內”標準化

p(k)p(k)ii=1mp(k)i(5)(5)p(k)←pi(k)∑i=1mpi(k)


但這樣一來,式(3)(3)可能又不成立了。所以理論上我們可以交替迭代執行這兩步,直到結果收斂(不過實驗結果顯示一般情況下一次的效果是最好的)。最後,我們只保留最新的p(j)p(j)作為原來第jj個樣本的預測結果,其餘的p(k)p(k)均棄之不用。

注意,這個過程需要我們遍歷每個低置信度結果j{n+1,n+2,,N}j∈{n+1,n+2,⋯,N}執行,也就是說是逐個樣本進行修正,而不是一次性修正的,每個p(j)p(j)都藉助原始的高置信度結果p(1),p(2),,p(n)p(1),p(2),⋯,p(n)組合來按照上述步驟迭代,雖然迭代過程中對應的p(1),p(2),,p(n)p(1),p(2),⋯,p(n)都會隨之更新,但那只是臨時結果,最後都是棄之不用的,每次修正都是用原始的p(1),p(2),,p(n)p(1),p(2),⋯,p(n)。

參考實現#

這是筆者給出的參考實現程式碼:

# 預測結果,計算修正前準確率
y_pred = model.predict(
    valid_generator.fortest(), steps=len(valid_generator), verbose=True
)
y_true = np.array([d[1] for d in valid_data])
acc_original = np.mean([y_pred.argmax(1) == y_true])
print('original acc: %s' % acc_original)

# 評價每個預測結果的不確定性
k = 3
y_pred_topk = np.sort(y_pred, axis=1)[:, -k:]
y_pred_topk /= y_pred_topk.sum(axis=1, keepdims=True)
y_pred_uncertainty = -(y_pred_topk * np.log(y_pred_topk)).sum(1) / np.log(k)

# 選擇閾值,劃分高、低置信度兩部分
threshold = 0.9
y_pred_confident = y_pred[y_pred_uncertainty < threshold]
y_pred_unconfident = y_pred[y_pred_uncertainty >= threshold]
y_true_confident = y_true[y_pred_uncertainty < threshold]
y_true_unconfident = y_true[y_pred_uncertainty >= threshold]

# 顯示兩部分各自的準確率
# 一般而言,高置信度集準確率會遠高於低置信度的
acc_confident = (y_pred_confident.argmax(1) == y_true_confident).mean()
acc_unconfident = (y_pred_unconfident.argmax(1) == y_true_unconfident).mean()
print('confident acc: %s' % acc_confident)
print('unconfident acc: %s' % acc_unconfident)

# 從訓練集統計先驗分佈
prior = np.zeros(num_classes)
for d in train_data:
    prior[d[1]] += 1.

prior /= prior.sum()

# 逐個修改低置信度樣本,並重新評價準確率
right, alpha, iters = 0, 1, 1
for i, y in enumerate(y_pred_unconfident):
    Y = np.concatenate([y_pred_confident, y[None]], axis=0)
    for j in range(iters):
        Y = Y**alpha
        Y /= Y.mean(axis=0, keepdims=True)
        Y *= prior[None]
        Y /= Y.sum(axis=1, keepdims=True)
    y = Y[-1]
    if y.argmax() == y_true_unconfident[i]:
        right += 1

# 輸出修正後的準確率
acc_final = (acc_confident * len(y_pred_confident) + right) / len(y_pred)
print('new unconfident acc: %s' % (right / (i + 1.)))
print('final acc: %s' % acc_final)

實驗結果#

那麼,這樣的簡單後處理,究竟能帶來多大的提升呢?原論文給出的實驗結果是相當可觀的:

原論文的實驗結果之一

筆者也在CLUE上的兩個中文文字分類任務上做了實驗,顯示基本也有點提升,但沒那麼可觀(驗證集結果):

BERTBERT + CANRoBERTaRoBERTa + CANIFLYTEK(類別數:119)60.06%60.52%60.64%60.95%TNEWS(類別數:15)56.80%56.86%58.06%58.00%IFLYTEK(類別數:119)TNEWS(類別數:15)BERT60.06%56.80%BERT + CAN60.52%56.86%RoBERTa60.64%58.06%RoBERTa + CAN60.95%58.00%

大體上來說,類別數目越多,效果提升越明顯,如果類別數目比較少,那麼可能提升比較微弱甚至會下降(當然就算下降也是微弱的),所以這算是一個“幾乎免費的午餐”了。超引數選擇方面,上面給出的中文結果,只迭代了1次,kk的選擇為3、ττ的選擇為0.9,經過簡單的除錯,發現這基本上已經是比較優的引數組合了。

還有的讀者可能想問前面說的“高置信度那部分結果更可靠”這個情況是否真的成立?至少在筆者的兩個中文實驗上它是明顯成立的,比如IFLYTEK任務,篩選出來的高置信度集準確率為0.63+,而低置信度集的準確率只有0.22+;TNEWS任務類似,高置信度集準確率為0.58+,而低置信度集的準確率只有0.23+。

個人評價

最後再來綜合地思考和評價一下CAN。

首先,一個很自然的疑問是為什麼不直接將所有低置信度結果跟高置信度結果拼在一起進行修正,而是要逐個進行修正?筆者不知道原論文作者有沒有對比過,但筆者確實實驗過這個想法,結果是批量修正有時跟逐個修正持平,但有時也會下降。其實也可以理解,CAN本意應該是藉助先驗分佈,結合高置信度結果來修正低置信度的,在這個過程中,如果摻入越多的低置信度結果,那麼最終的偏差可能就越大,因此理論上逐個修正會比批量修正更為可靠。

說到原論文,讀過CAN論文的讀者,應該能發現本文介紹與CAN原論文大致有三點不同:

1、不確定性指標的計算方法不同。按照原論文的描述,它最終的不確定性指標計算方式應該是

1logmi=1kpilogpi(6)−1log⁡m∑i=1kpilog⁡pi


也就是說,它也是top-kk個概率算熵的形式,但是它沒有對這kk個概率值重新歸一化,並且它將其壓縮到0~1之間的因子是logmlog⁡m而不是logklog⁡k(因為它沒有重新歸一化,所以只有除logmlog⁡m才能保證0~1之間)。經過筆者測試,原論文的這種方式計算出來的結果通常明顯小於1,這不利於我們對閾值的感知和除錯。

2、對CAN的介紹方式不同。原論文是純粹數學化、矩陣化地陳述CAN的演算法步驟,而且沒有介紹演算法的思想來源,這對理解CAN是相當不友好的。如果讀者沒有自行深入思考演算法原理,是很難理解為什麼這樣的後處理手段就能提升分類效果的,而在徹底弄懂之後則會有一種故弄玄虛之感。

3、CAN的演算法流程略有不同。原論文在迭代過程中還引入了引數αα,使得式(4)(4)變為

p(k)[p(k)]α/p¯×p~,p¯=1n+1([p(j)]α+i=1n[p(i)]α)(7)p(k)←[p(k)]α/p¯×p~,p¯=1n+1([p(j)]α+∑i=1n[p(i)]α)


也就是對每個結果進行αα次方後再迭代。當然,原論文也沒有對此進行解釋,而在筆者看來,該引數純粹是為了調參而引入的(引數多了,總能把效果調到有所提升),沒有太多實際意義。而且筆者自己在實驗中發現,α=1α=1基本已經是最優選擇了,精調αα也很難獲得是實質收益。

文章小結#

本文介紹了一種名為CAN的簡單後處理技巧,它藉助先驗分佈來將預測結果重新歸一化,幾乎沒有增加多少計算成本就能提高分類效能。經過筆者的實驗,CAN確實能給分類效果帶來一定提升,並且通常來說類別數越多,效果越明顯。