利用python中的matplotlib列印混淆矩陣例項
前面說過混淆矩陣是我們在處理分類問題時,很重要的指標,那麼如何更好的把混淆矩陣給打印出來呢,直接做表或者是前端視覺化,小編曾經就嘗試過用前端(D5)做出來,然後截圖,顯得不那麼好看。。
程式碼:
import itertools import matplotlib.pyplot as plt import numpy as np def plot_confusion_matrix(cm,classes,normalize=False,title='Confusion matrix',cmap=plt.cm.Blues): """ This function prints and plots the confusion matrix. Normalization can be applied by setting `normalize=True`. """ if normalize: cm = cm.astype('float') / cm.sum(axis=1)[:,np.newaxis] print("Normalized confusion matrix") else: print('Confusion matrix,without normalization') print(cm) plt.imshow(cm,interpolation='nearest',cmap=cmap) plt.title(title) plt.colorbar() tick_marks = np.arange(len(classes)) plt.xticks(tick_marks,rotation=45) plt.yticks(tick_marks,classes) fmt = '.2f' if normalize else 'd' thresh = cm.max() / 2. for i,j in itertools.product(range(cm.shape[0]),range(cm.shape[1])): plt.text(j,i,format(cm[i,j],fmt),horizontalalignment="center",color="white" if cm[i,j] > thresh else "black") plt.tight_layout() plt.ylabel('True label') plt.xlabel('Predicted label') plt.show() # plt.savefig('confusion_matrix',dpi=200) cnf_matrix = np.array([ [4101,2,5,24,0],[50,3930,6,14,5],[29,3,3973,4,[45,7,1,3878,119],[31,8,28,3936],]) class_names = ['Buildings','Farmland','Greenbelt','Wasteland','Water'] # plt.figure() # plot_confusion_matrix(cnf_matrix,classes=class_names,# title='Confusion matrix,without normalization') # Plot normalized confusion matrix plt.figure() plot_confusion_matrix(cnf_matrix,normalize=True,title='Normalized confusion matrix')
在放矩陣位置,放一下你的混淆矩陣就可以,當然視覺化混淆矩陣這一步也可以直接在模型執行中完成。
補充知識:混淆矩陣(Confusion matrix)的原理及使用(scikit-learn 和 tensorflow)
原理
在機器學習中,混淆矩陣是一個誤差矩陣,常用來視覺化地評估監督學習演算法的效能. 混淆矩陣大小為 (n_classes,n_classes) 的方陣,其中 n_classes 表示類的數量. 這個矩陣的每一行表示真實類中的例項,而每一列表示預測類中的例項 (Tensorflow 和 scikit-learn 採用的實現方式). 也可以是,每一行表示預測類中的例項,而每一列表示真實類中的例項 (Confusion matrix From Wikipedia 中的定義). 通過混淆矩陣,可以很容易看出系統是否會弄混兩個類,這也是混淆矩陣名字的由來.
混淆矩陣是一種特殊型別的列聯表(contingency table)或交叉製表(cross tabulation or crosstab). 其有兩維 (真實值 "actual" 和 預測值 "predicted" ),這兩維都具有相同的類("classes")的集合. 在列聯表中,每個維度和類的組合是一個變數. 列聯表以表的形式,視覺化地表示多個變數的頻率分佈.
使用混淆矩陣( scikit-learn 和 Tensorflow)
下面先介紹在 scikit-learn 和 tensorflow 中計算混淆矩陣的 API (Application Programming Interface) 介面函式,然後在一個示例中,使用這兩個 API 函式.
scikit-learn 混淆矩陣函式 sklearn.metrics.confusion_matrix API 介面
skearn.metrics.confusion_matrix( y_true,# array,Gound true (correct) target values y_pred,Estimated targets as returned by a classifier labels=None,List of labels to index the matrix. sample_weight=None # array-like of shape = [n_samples],Optional sample weights )
在 scikit-learn 中,計算混淆矩陣用來評估分類的準確度.
按照定義,混淆矩陣 C 中的元素 Ci,j 等於真實值為組 i,而預測為組 j 的觀測數(the number of observations). 所以對於二分類任務,預測結果中,正確的負例數(true negatives,TN)為 C0,0; 錯誤的負例數(false negatives,FN)為 C1,0; 真實的正例數為 C1,1; 錯誤的正例數為 C0,1.
如果 labels 為 None,scikit-learn 會把在出現在 y_true 或 y_pred 中的所有值新增到標記列表 labels 中,並排好序.
Tensorflow 混淆矩陣函式 tf.confusion_matrix API 介面
tf.confusion_matrix( labels,# 1-D Tensor of real labels for the classification task predictions,# 1-D Tensor of predictions for a givenclassification num_classes=None,# The possible number of labels the classification task can have dtype=tf.int32,# Data type of the confusion matrix name=None,# Scope name weights=None,# An optional Tensor whose shape matches predictions )
Tensorflow tf.confusion_matrix 中的 num_classes 引數的含義,與 scikit-learn sklearn.metrics.confusion_matrix 中的 labels 引數相近,是與標記有關的引數,表示類的總個數,但沒有列出具體的標記值. 在 Tensorflow 中一般是以整數作為標記,如果標記為字串等非整數型別,則需先轉為整數表示. 如果 num_classes 引數為 None,則把 labels 和 predictions 中的最大值 + 1,作為num_classes 引數值.
tf.confusion_matrix 的 weights 引數和 sklearn.metrics.confusion_matrix 的 sample_weight 引數的含義相同,都是對預測值進行加權,在此基礎上,計算混淆矩陣單元的值.
使用示例
#!/usr/bin/env python # -*- coding: utf8 -*- """ Author: klchang Description: A simple example for tf.confusion_matrix and sklearn.metrics.confusion_matrix. Date: 2018.9.8 """ from __future__ import print_function import tensorflow as tf import sklearn.metrics y_true = [1,4] y_pred = [2,4] # Build graph with tf.confusion_matrix operation sess = tf.InteractiveSession() op = tf.confusion_matrix(y_true,y_pred) op2 = tf.confusion_matrix(y_true,y_pred,num_classes=6,dtype=tf.float32,weights=tf.constant([0.3,0.4,0.3])) # Execute the graph print ("confusion matrix in tensorflow: ") print ("1. default: \n",op.eval()) print ("2. customed: \n",sess.run(op2)) sess.close() # Use sklearn.metrics.confusion_matrix function print ("\nconfusion matrix in scikit-learn: ") print ("1. default: \n",sklearn.metrics.confusion_matrix(y_true,y_pred)) print ("2. customed: \n",labels=range(6),sample_weight=[0.3,0.3]))
以上這篇利用python中的matplotlib列印混淆矩陣例項就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。