1. 程式人生 > 其它 >(九)混淆矩陣與繪圖

(九)混淆矩陣與繪圖

一、基本概念

當說到召回率的時候就說到了混淆矩陣。

再回顧一下召回率吧,案例中有100個正例,猜中(預測對)了59個,我們就說召回率為59%。

召回率就是猜中率。

當時也講到,正例和反例,加上猜中和猜錯,總共有四種情況


所謂召回率,僅僅是其中的四分之一。在條件允許(資本充足)的情況下,我們關心的,也是實際有用的,的確是召回率。

但是實際條件並不允許我們這麼單一,現實對我們的要求不僅是增加猜中的概率,也需要降低猜錯的概率。

同時,關鍵的一個隱蔽點,在於數量的限制,50個男生,50個女生,我猜全部是男生,就會發現這種奇葩情況:

召回率100%,但是其他分佈慘不忍睹。

隱藏的,就是可以猜的個數。

當然,我們可以把猜的個數做一個限制,但是這只是在已知的情境下才有具體的作用,位置的情況下,誰也說不準100個人中到底有多少個男生,多少個女生,可取的範圍的確是[0,100]。

綜上所述,對於一個模型的評估,所謂的召回率只能是在其他情況下都"不太差"的情況下才有對比的意義,或者說是隻在乎"召回率",也就是錯殺一千也不放過一個,不在乎浪費和消耗的情況下才有追逐的價值。

普遍的情況,追求的當然是全面,用最少的資源做最多的事情。也就是說,我們需要對樣本的分佈和預測的分佈進行綜合的考量,從各方面對模型進行評估和約束,才能夠達到預期的目標。

而上面的2*2的分佈表格,就是我們所謂的混淆矩陣。

當樣本分佈為3類的時候,猜測也為3類


其他先不管,至少我們可以先得出一個結論:

混淆矩陣始終是方陣。

把對錯繼續劃分,樣本除了猜對和猜錯,具體可以劃分為猜對,猜成?類,資料中類越多,這種也就更加具體規範。

混淆矩陣的意義在於彌補錯(and)誤,我們也要明晰這個誤區:不是成功率提高了錯誤率就會降低。

或許更具體的說起,就是我們的"成功"也是有水分的,TP是成功,TN是FP水分,TN是錯漏,FT是有效排除。

真實的結果不僅在於找到對的,還在於排除錯的。單方面的前進,或許會覆蓋正確,但是也會錯過正確,偏離正確。

只有兩頭逼近,才能夠真正的定位正確,或鎖定在一個較小的區間範圍內(夾逼定則)。

二、函式

1、召回率

from sklearn.metrics import confusion_matrix
from sklearn.metrics import recall_score
 
guess = [1, 0, 1, 0, 1, 0, 1, 0, 1, 0]
fact = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
a = recall_score(guess, fact)
print(a)
'''
0.0
'''

這個會自動計算,主要的過程就是"對號入座"

(1)資料對
(2)位置對
首先,它會把實際結果和預測結果組成資料對,才到後來的判斷階段。

判斷的時候,只考慮對錯,按照如下的表進行計算。


全部的數組合成一個個資料對,然後按照這種分佈情況表進行統計,正對角線上的都是預測正確的,這就加一,最後正確數的除以總數就得出來了所謂的召回率。

from sklearn.metrics import confusion_matrix
from sklearn.metrics import recall_score
 
guess = [1, 0, 1, 0, 1, 0, 1, 0, 1, 0]
fact = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
a = confusion_matrix(guess, fact)
print(a)
'''
[[0 5]
 [5 0]]
'''

圖表表示一下

恩...說了半天都覺得混淆矩陣比召回率高階多了,事實打臉了,的確是有了混淆矩陣(混淆統計)才計算召回率的。

即使是召回率感覺高階一些,但是混淆矩陣更詳細,這才是避免更大失誤的關注點。

尤其是多種分類的情況下

from sklearn.metrics import confusion_matrix
from sklearn.metrics import recall_score
 
guess = [1, 0, 1, 2, 1, 0, 1, 0, 1, 0]
fact = [0, 1, 0, 1, 2, 1, 0, 1, 0, 1]
a = confusion_matrix(guess, fact)
print(a)
'''
[[0 4 0]
 [4 0 1]
 [0 1 0]]
'''

三、繪圖

混淆矩陣重要吧,不過誰知道啊,誰關心啊,資料人家感觸不到,也不一定深刻理解,怎麼辦,畫圖唄。

from sklearn.metrics import confusion_matrix
from sklearn.metrics import recall_score
import matplotlib.pyplot as plt
 
guess = [1, 0, 1]
fact = [0, 1, 0]
classes = list(set(fact))
classes.sort()
confusion = confusion_matrix(guess, fact)
plt.imshow(confusion, cmap=plt.cm.Blues)
indices = range(len(confusion))
plt.xticks(indices, classes)
plt.yticks(indices, classes)
plt.colorbar()
plt.xlabel('guess')
plt.ylabel('fact')
for first_index in range(len(confusion)):
    for second_index in range(len(confusion[first_index])):
        plt.text(first_index, second_index, confusion[first_index][second_index])
 
plt.show()

複雜一點

from sklearn.metrics import confusion_matrix
from sklearn.metrics import recall_score
import matplotlib.pyplot as plt
 
guess = [1, 0, 1, 2, 1, 0, 1, 0, 1, 0]
fact = [0, 1, 0, 1, 2, 1, 0, 1, 0, 1]
classes = list(set(fact))
classes.sort()
confusion = confusion_matrix(guess, fact)
plt.imshow(confusion, cmap=plt.cm.Blues)
indices = range(len(confusion))
plt.xticks(indices, classes)
plt.yticks(indices, classes)
plt.colorbar()
plt.xlabel('guess')
plt.ylabel('fact')
for first_index in range(len(confusion)):
    for second_index in range(len(confusion[first_index])):
        plt.text(first_index, second_index, confusion[first_index][second_index])
 
plt.show()

講解一波

from sklearn.metrics import confusion_matrix
from sklearn.metrics import recall_score
import matplotlib.pyplot as plt
 
 
# 預測資料,predict之後的預測結果集
guess = [1, 0, 1, 2, 1, 0, 1, 0, 1, 0]
# 真實結果集
fact = [0, 1, 0, 1, 2, 1, 0, 1, 0, 1]
# 類別
classes = list(set(fact))
# 排序,準確對上分類結果
classes.sort()
# 對比,得到混淆矩陣
confusion = confusion_matrix(guess, fact)
# 熱度圖,後面是指定的顏色塊,gray也可以,gray_x反色也可以
plt.imshow(confusion, cmap=plt.cm.Blues)
# 這個東西就要注意了
# ticks 這個是座標軸上的座標點
# label 這個是座標軸的註釋說明
indices = range(len(confusion))
# 座標位置放入
# 第一個是迭代物件,表示座標的順序
# 第二個是座標顯示的數值的陣列,第一個表示的其實就是座標顯示數字陣列的index,但是記住必須是迭代物件
plt.xticks(indices, classes)
plt.yticks(indices, classes)
# 熱度顯示儀?就是旁邊的那個驗孕棒啦
plt.colorbar()
# 就是座標軸含義說明了
plt.xlabel('guess')
plt.ylabel('fact')
# 顯示資料,直觀些
for first_index in range(len(confusion)):
    for second_index in range(len(confusion[first_index])):
        plt.text(first_index, second_index, confusion[first_index][second_index])
 
# 顯示
plt.show()
 
# PS:注意座標軸上的顯示,就是classes
# 如果資料正確的,對應關係顯示錯了就功虧一簣了
# 一個錯誤發生,想要說服別人就更難了