1. 程式人生 > >使用keras的fit_generator來獲得混淆矩陣Confusion Matrix

使用keras的fit_generator來獲得混淆矩陣Confusion Matrix

還是google過來的方法,說明它還是挺靠譜滴。這裡有必要記錄一下。

關於混亂淆矩陣是用來幹嘛的,這裡就不說了。各位可以百度or谷歌。

關於如何使用fit_generator來進行訓練可以看我上一篇文章。

我們在使用fit_generator方法來進行訓練的時候,是不需要自己讀取x_img_train,y_label_train的。都是generator幫我們做好了。

但是要想畫混淆矩陣的話得要 驗證集中原始影象的標籤與預測到的標籤值。

model.fit下面有對應的evaluate,predict 等,那麼model.fit_generator下面自然也有對應的,且看:

fit_generator

evaluate_generator:

evaluate_generator(generator, steps=None, max_queue_size=10, workers=1, use_multiprocessing=False, verbose=0)

其它很多引數我們暫時用不到,可以寫成(關於validation_generator可以去看我上一篇文章):

model.evaluate_generator(validation_generator,verbose=1)

predict_generator:

predict_generator(generator, steps=None, max_queue_size=10, workers=1, use_multiprocessing=False, verbose=0)

也可以直接寫成:

prediction=model.predict_generator(validation_generator,verbose=1)

接下來我們對prediction做一下處理

#因為prediction是一個n行5列的陣列,我們要把它轉換成一維陣列
#這樣的話每個值就會與驗證集中的標籤值一一對應上了。
predict_label=np.argmax(prediction,axis=1)

驗證集中真實資料的標籤為:

true_label=validation_generator.classes

好了,混淆矩陣中所需要的兩個引數我們都已經得到了。

1:true_label  真實資料標籤

2:predict_label  預測的資料標籤

接下來使用pd.crosstab來簡單畫出混淆矩陣

import pandas as pd
pd.crosstab(true_label,predict_label,rownames=['label'],colnames=['predict'])

 

下面是操作的注意事項:

我這裡的 validation_generator 是沒有被shuffle的。這樣的話正好與後面的真實標籤跟預測標籤一一對應。

如果前面被shuffle的話,這邊肯定就對不上了。

不知道網友們有沒有什麼更好的方法來把這兩種標籤相互對映?

 

文章是對著我自己的專案寫的,具體資料這裡沒給出。如果各位看不懂的話可以留言,我把程式碼完整貼上......