1. 程式人生 > >170720 混淆矩陣繪製+pandas讀取資料(有點亂,後面抽空再整理)

170720 混淆矩陣繪製+pandas讀取資料(有點亂,後面抽空再整理)

E:\Backup\validation confusion matrix_final2

# -*- coding: utf-8 -*-
"""
Created on Fri May 19 11:17:12 2017

@author: Bruce Lau
"""
#%%
import itertools
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats

from sklearn.metrics import confusion_matrix
from p4_1_ds_competition import idscnn2


#%%
def reverse_onehot(onehot):
    """Convert rows of class scores into class indices.

    Parameters
    ----------
    onehot : iterable of 1-D array-like
        One row of scores per sample.

    Returns
    -------
    list of int
        For every row, the index of its largest entry (the first such index
        when the maximum is tied).
    """
    return [int(np.argmax(row)) for row in onehot]


#%%
def cal_acc(m1, m2):
    """Return the fraction of samples whose predicted class equals the label.

    Parameters
    ----------
    m1 : iterable of 1-D array-like
        Per-sample class scores; reduced to class indices with
        ``reverse_onehot``.
    m2 : array-like
        Class index labels, one per sample.

    Returns
    -------
    float
        Share of positions where prediction and label agree.
    """
    predicted = np.asarray(reverse_onehot(m1))
    # np.mean of the boolean mask is the accuracy and never truncates the way
    # an integer ``sum(...) / len(...)`` can.
    return np.mean(predicted == np.asarray(m2))


#%%
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """Print a status line and draw the confusion matrix on the current axes.

    Normalization can be applied by setting `normalize=True`; each row is then
    divided by its row sum.

    Parameters
    ----------
    cm : numpy.ndarray
        2-D confusion matrix. It is not modified.
    classes : sequence
        Tick labels used on both axes.
    normalize : bool
        Whether to row-normalize ``cm`` before drawing.
    title : str
        Axes title.
    cmap : matplotlib colormap
        Colormap handed to ``plt.imshow``.
    """
    # Normalize BEFORE drawing so that the colours, the printed numbers and
    # the text-colour threshold are all derived from the same values.
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    im = plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar(im, fraction=0.046, pad=0.04)
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes)
    plt.yticks(tick_marks, classes)

    # Cells darker than half of the maximum get white text, the rest black.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        # Round only the displayed value; the matrix itself is left untouched.
        value = round(cm[i, j], 3)
        plt.text(j, i, value,
                 horizontalalignment="center",
                 color="white" if value > thresh else "black")

    plt.ylabel('True fault type')
    plt.xlabel('CNN prediction fault type')


#%%
def cm_plot(y_, y, name, idx):
    """Draw one confusion matrix into slot ``idx`` of a 1x3 subplot row.

    Parameters
    ----------
    y_ : array-like
        First argument to ``confusion_matrix`` (``save_cm`` passes the labels).
    y : array-like
        Second argument to ``confusion_matrix`` (``save_cm`` passes the
        predicted classes).
    name : str
        Title shown above the subplot.
    idx : int
        1-based subplot position. When it equals 3 the figure is saved as
        ``'3.png'`` at 300 dpi and shown.
    """
    class_names = np.array([str(digit) for digit in range(10)])
    cnf_matrix = confusion_matrix(y_, y)
    np.set_printoptions(precision=2)

    plt.subplot(1, 3, idx)
    # The matrix is drawn with raw counts (normalize=False); ``name`` is the
    # title that ends up on the axes.
    plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=False,
                          title=name)
    if idx == 3:
        # NOTE(review): the original indentation was lost; saving and showing
        # only after the third panel is the reading that keeps all three
        # subplots in one figure - confirm.
        plt.savefig(str(idx) + '.png', dpi=300)
        plt.show()


#%%
def papershow(de_c, fe_c):
    """Compute the accuracy of two prediction sets and of their fusion.

    Labels are read from ``'labels.npy'`` in the working directory. Every
    sample's two score rows are passed together to ``idscnn2`` (imported from
    ``p4_1_ds_competition``; its implementation is not visible in this file)
    and the returned row is stored as the fused prediction.

    Parameters
    ----------
    de_c : array-like
        First set of per-sample scores (the script passes the ``*_de.npy``
        arrays).
    fe_c : array-like
        Second set of per-sample scores, same shape as ``de_c`` (the script
        passes the ``*_fe.npy`` arrays).

    Returns
    -------
    tuple
        ``(accuracies, fused)`` where ``accuracies`` is
        ``np.array([de_acc, fe_acc, fusion_acc])`` and ``fused`` holds one
        fused row per sample.
    """
    labels = np.load('labels.npy')

    de_acc = cal_acc(de_c, labels)
    fe_acc = cal_acc(fe_c, labels)

    # Size the working arrays from the input instead of a fixed (2500, 10).
    de_c = np.asarray(de_c)
    me = np.ones(de_c.shape + (2,))
    me[..., 0] = de_c
    me[..., 1] = fe_c

    re = np.ones(de_c.shape)
    for i, stack in enumerate(me):
        # stack.T has one row per source model for this sample.
        re[i] = idscnn2(stack.T)

    fusion_result = cal_acc(re, labels)
    return np.array([de_acc, fe_acc, fusion_result]), re


#%%
def save_cm(path1, path2):
    """Plot the three confusion matrices for one pair of prediction files.

    Parameters
    ----------
    path1 : str
        ``.npy`` file with the first model's scores.
    path2 : str
        ``.npy`` file with the second model's scores.

    Notes
    -----
    The accuracy percentages in the subplot titles are fixed strings; they are
    not computed from the loaded data.
    """
    de = np.load(path1)
    fe = np.load(path2)
    _, fused = papershow(de, fe)

    labels = np.load('labels.npy')
    de_classes = reverse_onehot(de)
    fe_classes = reverse_onehot(fe)
    fused_classes = reverse_onehot(fused)

    plt.figure(facecolor='w', figsize=(16, 4))
    cm_plot(labels, de_classes,
            '# 7 CNN model @ drive end\n accuracy=83.8%', 1)
    cm_plot(labels, fe_classes,
            '# 25 CNN model @ fan end\n accuracy=79.2%', 2)
    cm_plot(labels, fused_classes,
            'Fused model for # 7 and # 25\n accuracy=92.4%', 3)


#%% accuracy of both models and of the fusion over the 20 pre-processing runs
accuracy = np.ones((20, 3))
for i in range(1, 21):
    print(i)
    de = np.load('cnn_pre_pro2/pre_pro_' + str(i) + '/CA/107_de.npy')
    fe = np.load('cnn_pre_pro2/pre_pro_' + str(i) + '/CA/109_fe.npy')
    acc, re = papershow(de, fe)
    accuracy[i - 1, :] = acc

#%%
path1 = 'cnn_pre_pro2/pre_pro_1//CA/107_de.npy'
path2 = 'cnn_pre_pro2/pre_pro_1//CA/109_fe.npy'
# save_cm(path1,path2)

#%% t-test
t_de = accuracy[:, 0]
t_fe = accuracy[:, 1]
t_me = accuracy[:, 2]
# equal_var=False selects Welch's t-test (no equal-variance assumption).
t_de_me = stats.ttest_ind(t_de, t_me, equal_var=False)
t_fe_me = stats.ttest_ind(t_fe, t_me, equal_var=False)
# Index 1 of each result is the p-value.
print('t-test significant difference between de and me is: %f' % t_de_me[1])
# Report the fe-vs-me result here (the de-vs-me value used to be printed twice).
print('t-test significant difference between fe and me is: %f' % t_fe_me[1])
avg = np.mean(accuracy, axis=0)
print("The averages are: ", avg)

#%%
import pandas as pd

# ``sheet_name`` is the current keyword; the old ``sheetname`` spelling is
# rejected by current pandas.
data = pd.read_excel('statistical-analysis2.xlsx', sheet_name=1, skiprows=1)
data_array = data.values
print("The mean values of DS and IDS are \n", data.mean())
print('\n')
print("The std values of DS and IDS are \n", data.std())