使用keras的fit_generator來獲得混淆矩陣Confusion Matrix
阿新 • • 發佈:2018-12-09
還是google過來的方法,說明它還是挺靠譜滴。這裡有必要記錄一下。
關於混亂淆矩陣是用來幹嘛的,這裡就不說了。各位可以百度or谷歌。
關於如何使用fit_generator來進行訓練可以看我上一篇文章。
我們在使用fit_generator方法來進行訓練的時候,是不需要自己讀取x_img_train,y_label_train的。都是generator幫我們做好了。
但是要想畫混淆矩陣的話得要 驗證集中原始影象的標籤與預測到的標籤值。
model.fit下面有對應的evaluate,predict 等,那麼model.fit_generator下面自然也有對應的,且看:
fit_generator
evaluate_generator:
evaluate_generator(generator, steps=None, max_queue_size=10, workers=1, use_multiprocessing=False, verbose=0)
其它很多引數我們暫時用不到,可以寫成(關於validation_generator可以去看我上一篇文章):
model.evaluate_generator(validation_generator,verbose=1)
predict_generator:
predict_generator(generator, steps=None, max_queue_size=10, workers=1, use_multiprocessing=False, verbose=0)
也可以直接寫成:
prediction=model.predict_generator(validation_generator,verbose=1)
接下來我們對prediction做一下處理
#因為prediction是一個n行5列的陣列,我們要把它轉換成一維陣列
#這樣的話每個值就會與驗證集中的標籤值一一對應上了。
predict_label=np.argmax(prediction,axis=1)
驗證集中真實資料的標籤為:
true_label=validation_generator.classes
好了,混淆矩陣中所需要的兩個引數我們都已經得到了。
1:true_label 真實資料標籤
2:predict_label 預測的資料標籤
接下來使用pd.crosstab來簡單畫出混淆矩陣
import pandas as pd
pd.crosstab(true_label,predict_label,rownames=['label'],colnames=['predict'])
下面是操作的注意事項:
我這裡的 validation_generator 是沒有被shuffle的。這樣的話正好與後面的真實標籤跟預測標籤一一對應。
如果前面被shuffle的話,這邊肯定就對不上了。
不知道網友們有沒有什麼更好的方法來把這兩種標籤相互對映?
文章是對著我自己的專案寫的,具體資料這裡沒給出。如果各位看不懂的話可以留言,我把程式碼完整貼上......