170720 混淆矩陣繪製+pandas讀取資料(有點亂,後面抽空再整理)
阿新 • 發佈:2019-01-07
E:\Backup\validation confusion matrix_final2
# -*- coding: utf-8 -*-
"""
Created on Fri May 19 11:17:12 2017
@author: Bruce Lau
"""
#%%
import itertools
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
from sklearn.metrics import confusion_matrix
from p4_1_ds_competition import idscnn2
# Convert one-hot / score rows back to integer class labels.
def reverse_onehot(onehot):
    """Return the index of the largest entry of each row of ``onehot``.

    ``onehot`` is iterated row by row and every row is converted with
    ``.tolist()``, so rows are expected to be NumPy arrays. When several
    entries share the maximum, the first (lowest) index wins, as with
    ``list.index``.

    Returns a plain Python list of ints, one per row.
    """
    def _first_argmax(values):
        # position of the first occurrence of the maximum value
        return values.index(max(values))

    return [_first_argmax(row.tolist()) for row in onehot]
#%%
def cal_acc(m1, m2):
    """Return the classification accuracy of score rows ``m1`` against labels ``m2``.

    Parameters
    ----------
    m1 : array of per-class scores / one-hot rows; the predicted class of a
        row is the index of its largest entry (via ``reverse_onehot``).
    m2 : 1-D sequence of ground-truth class indices, one per row of ``m1``.

    Returns
    -------
    numpy.float64
        Fraction of rows whose predicted class equals the label.

    Raises
    ------
    ValueError
        If the number of predictions and labels differ, or there are none.
    """
    predicted = np.asarray(reverse_onehot(m1))
    labels = np.asarray(m2)
    # A shape mismatch would otherwise silently broadcast or fail obscurely.
    if predicted.shape != labels.shape:
        raise ValueError('cal_acc: %d predictions but labels have shape %s'
                         % (predicted.size, labels.shape))
    if labels.size == 0:
        raise ValueError('cal_acc: no samples to score')
    return np.mean(predicted == labels)
#%%
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """Draw a confusion matrix as an annotated heat map on the current axes.

    Parameters
    ----------
    cm : 2-D array where ``cm[i, j]`` counts samples of true class ``i``
        predicted as class ``j`` (the layout returned by
        ``sklearn.metrics.confusion_matrix``).
    classes : tick labels, one per row/column of ``cm``.
    normalize : if True, divide each row by its sum so cells show the
        fraction of each true class; rows with no samples are shown as 0.
    title : axes title.
    cmap : matplotlib colormap for the heat map.

    Prints whether the matrix was normalized. The caller's ``cm`` is not
    modified.
    """
    # Normalise *before* drawing so the colours and the cell text agree.
    if normalize:
        cm = cm.astype('float')
        row_sums = cm.sum(axis=1, keepdims=True)
        # A row with no samples would give 0/0 = NaN; keep it at 0 instead.
        cm = np.divide(cm, row_sums, out=np.zeros_like(cm),
                       where=row_sums != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    im = plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    # fraction/pad: values commonly used to size the colour bar to the image
    plt.colorbar(im, fraction=0.046, pad=0.04)
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes)
    plt.yticks(tick_marks, classes)

    # White text on dark cells, black text on light cells.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        value = round(cm[i, j], 3)  # local copy; do not write back into cm
        plt.text(j, i, value,
                 horizontalalignment="center",
                 color="white" if value > thresh else "black")
    plt.ylabel('True fault type')
    plt.xlabel('CNN prediction fault type')
#%%
def cm_plot(y_, y, name, idx, ncols=3):
    """Draw the confusion matrix of ``y`` vs ``y_`` as panel ``idx`` of a 1 x ``ncols`` row.

    Parameters
    ----------
    y_ : true class labels (presumably integers 0-9, matching the fixed
        tick labels below).
    y : predicted class labels.
    name : title of this panel.
    idx : 1-based position of the panel in the row of subplots.
    ncols : number of panels in the row (default 3, the original layout).
        When ``idx == ncols`` the figure is complete: it is saved as
        ``'<idx>.png'`` at 300 dpi and shown.
    """
    class_names = np.array([str(k) for k in range(10)])
    # Fix the label set so the matrix is always 10 x 10 and lines up with the
    # tick labels even when some class never occurs in y_ or y.
    cnf_matrix = confusion_matrix(y_, y, labels=np.arange(len(class_names)))
    # Global NumPy print setting; left in place in case later output relies on it.
    np.set_printoptions(precision=2)
    plt.subplot(1, ncols, idx)
    plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=False,
                          title=name)
    if idx == ncols:
        plt.savefig(str(idx) + '.png', dpi=300)
        plt.show()
#%%
def papershow(de_c, fe_c):
    """Score the drive-end (DE) model, the fan-end (FE) model and their IDS fusion.

    Ground truth is read from ``labels.npy`` in the current working directory.

    Parameters
    ----------
    de_c, fe_c : 2-D arrays of shape (n_samples, n_classes) with the
        per-class outputs of the two CNNs, in the same row order as the
        labels in ``labels.npy``.

    Returns
    -------
    (accuracies, fused)
        ``accuracies`` is ``np.array([de_acc, fe_acc, fusion_acc])``;
        ``fused`` is the (n_samples, n_classes) float array of outputs
        returned row by row by ``idscnn2``.

    Raises
    ------
    ValueError
        If the two inputs are not 2-D arrays of the same shape.
    """
    de_c = np.asarray(de_c)
    fe_c = np.asarray(fe_c)
    if de_c.ndim != 2 or de_c.shape != fe_c.shape:
        raise ValueError('papershow: expected two equal 2-D arrays, got %s and %s'
                         % (de_c.shape, fe_c.shape))
    labels = np.load('labels.npy')
    de_acc = cal_acc(de_c, labels)
    fe_acc = cal_acc(fe_c, labels)

    # me[i] is an (n_classes, 2) matrix: column 0 = DE outputs, column 1 = FE.
    n_samples, n_classes = de_c.shape
    me = np.stack([de_c, fe_c], axis=2).astype(np.float64)
    fused = np.ones((n_samples, n_classes))
    for i in range(n_samples):
        # idscnn2 receives a (2, n_classes) matrix, one row per model.
        fused[i] = idscnn2(me[i].T)

    fusion_result = cal_acc(fused, labels)
    return np.array([de_acc, fe_acc, fusion_result]), fused
#%%
def save_cm(path1, path2):
    """Plot side-by-side confusion matrices for the DE, FE and fused models.

    Parameters
    ----------
    path1 : path to the ``.npy`` file with the drive-end (DE) CNN outputs.
    path2 : path to the ``.npy`` file with the fan-end (FE) CNN outputs.

    The three panels are drawn on a new 16 x 4 inch figure; ``cm_plot``
    saves and shows the figure when the third panel is drawn. Panel titles
    report the accuracies computed from the loaded data.
    """
    de = np.load(path1)
    fe = np.load(path2)
    acc, fused = papershow(de, fe)
    labels = np.load('labels.npy')

    de_pred = reverse_onehot(de)
    fe_pred = reverse_onehot(fe)
    fused_pred = reverse_onehot(fused)

    plt.figure(facecolor='w', figsize=(16, 4))
    # Accuracies come from papershow rather than being hard-coded, so the
    # titles stay correct for whichever run is plotted.
    cm_plot(labels, de_pred,
            '# 7 CNN model @ drive end\n accuracy=%.1f%%' % (100 * acc[0]), 1)
    cm_plot(labels, fe_pred,
            '# 25 CNN model @ fan end\n accuracy=%.1f%%' % (100 * acc[1]), 2)
    cm_plot(labels, fused_pred,
            'Fused model for # 7 and # 25\n accuracy=%.1f%%' % (100 * acc[2]), 3)
    # NOTE(review): the original had stray markers "#109" and "#112" here,
    # presumably other model ids that were tried -- confirm.
#%% Accuracy of the DE model, the FE model and their fusion over 20 runs
# Row r of ``accuracy`` receives [de_acc, fe_acc, fusion_acc] for run r + 1,
# loaded from cnn_pre_pro2/pre_pro_<run>/CA/.
accuracy = np.ones((20, 3))
for i in np.arange(1, 21):
    print(i)
    prefix = 'cnn_pre_pro2/pre_pro_' + str(i)
    de = np.load(prefix + '/CA/107_de.npy')
    fe = np.load(prefix + '/CA/109_fe.npy')
    acc, re = papershow(de, fe)
    accuracy[i - 1] = acc
#%% Confusion-matrix figure for run 1
# NOTE(review): these paths contain '//'; presumably harmless, kept unchanged.
path1, path2 = ('cnn_pre_pro2/pre_pro_1//CA/107_de.npy',
                'cnn_pre_pro2/pre_pro_1//CA/109_fe.npy')
save_cm(path1, path2)
#%% t-test: fused model vs each single model over the repeated runs
t_de = accuracy[:, 0]
t_fe = accuracy[:, 1]
t_me = accuracy[:, 2]
# equal_var=False selects Welch's t-test (no equal-variance assumption).
t_de_me = stats.ttest_ind(t_de, t_me, equal_var=False)
t_fe_me = stats.ttest_ind(t_fe, t_me, equal_var=False)
# Element [1] of each result is the p-value.
print('t-test significant difference between de and me is: %f' % t_de_me[1])
# Fixed: previously printed the de-vs-me p-value here as well.
print('t-test significant difference between fe and me is: %f' % t_fe_me[1])
avg = np.mean(accuracy, axis=0)
print("The averages are: ", avg)
#%% Summary statistics of the DS / IDS results spreadsheet
import pandas as pd
# ``sheetname`` was deprecated in pandas 0.21 in favour of ``sheet_name`` and
# is rejected by current pandas. Sheet index 1 is the second worksheet;
# skiprows=1 skips the first row of that sheet.
data = pd.read_excel('statistical-analysis2.xlsx', sheet_name=1, skiprows=1)
data_array = data.values
# numeric_only=True summarises only numeric columns (pandas >= 2.0 raises on
# non-numeric columns otherwise; older versions silently dropped them).
print("The mean values of DS and IDS are \n", data.mean(numeric_only=True))
print('\n')
print("The std values of DS and IDS are \n", data.std(numeric_only=True))