Sklearn_SVM 實現手寫數字識別
阿新 • 發佈:2019-01-08
1、資料準備
from sklearn.model_selection import StratifiedShuffleSplit
import pandas as pd
import numpy as np
from sklearn.datasets import fetch_mldata
class Data_need():
    """Load MNIST and turn it into stratified, binarized train/test arrays.

    Parameters
    ----------
    percent : float
        Fraction of samples placed in the test split (passed as
        ``test_size`` to ``StratifiedShuffleSplit``).
    data_name : str
        Dataset name handed to ``fetch_mldata``.
    data_home : str, optional
        Directory where the downloaded dataset is cached. Defaults to the
        author's original local path so existing callers behave as before.
    """

    DEFAULT_DATA_HOME = r'D:\Python_data\python Data\sklearn'

    def __init__(self, percent, data_name, data_home=DEFAULT_DATA_HOME):
        self.percent = percent
        self.data_name = data_name
        self.data_home = data_home

    def get_data(self):
        """Fetch the dataset (or load it from ``data_home``); return ``(data, target)``."""
        mnist = fetch_mldata(self.data_name, data_home=self.data_home)
        return mnist['data'], mnist['target']

    def random_data(self, x, y):
        """Shuffle ``x``/``y`` into a stratified train/test pair of DataFrames.

        Each returned DataFrame holds feature columns ``0 .. n_features-1``
        followed by the label column ``'y'``. The split is stratified on
        ``y`` and reproducible (``random_state=42``).

        Raises ValueError if ``x`` and ``y`` have different numbers of rows
        (column assignment refuses mismatched lengths instead of silently
        truncating, which an index merge would do).
        """
        n_features = len(x[0])
        mnist_data = pd.DataFrame(x, columns=list(range(n_features)))
        mnist_data['y'] = np.ravel(y)

        split = StratifiedShuffleSplit(n_splits=1, test_size=self.percent, random_state=42)
        # n_splits=1, so the generator yields exactly one (train, test) pair.
        train_index, test_index = next(split.split(mnist_data, mnist_data['y']))
        # split() yields *positional* indices, so select rows with iloc;
        # label-based loc would only work by accident on a RangeIndex.
        return mnist_data.iloc[train_index], mnist_data.iloc[test_index]

    def train_test_data(self, train, test):
        """Split features from labels and binarize pixels (non-zero -> 1).

        Expects the label in column ``'y'`` and the features in all columns
        before the last one, as produced by ``random_data``.
        Returns ``(x_train, y_train, x_test, y_test)`` as NumPy arrays.
        """
        def _binarize(frame):
            # Any non-zero pixel intensity becomes 1, zero stays 0.
            return (np.asarray(frame.iloc[:, :-1]) != 0).astype(int)

        return (_binarize(train), np.asarray(train['y']),
                _binarize(test), np.asarray(test['y']))
if __name__ == '__main__':
    # Hold out 30% of 'MNIST original' as a stratified test split,
    # then binarize the pixels of both halves.
    data_need = Data_need(0.3, 'MNIST original')
    x, y = data_need.get_data()
    train, test = data_need.random_data(x, y)
    (x_train_in, y_train_in,
     x_test_in, y_test_in) = data_need.train_test_data(train, test)
2、檢視資料及模型訓練
模型採用一對多(OvR,又稱 OvA)的線性 SVM 模型
from sklearn.svm import LinearSVC
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
# Apply matplotlib's 'ggplot' style sheet to every figure drawn below.
plt.style.use('ggplot')
def to_plot(num, n, x=None, y=None, image_shape=(28, 28)):
    """Display the ``n``-th sample whose label equals ``num`` as an image.

    num : label value to select (compared with ``==`` against ``y``).
    n : position of the sample among those labelled ``num``.
    x, y : feature matrix and label vector to draw from. When omitted they
        default to the module-level ``x_train_in`` / ``y_train_in`` built in
        the data-preparation step, which keeps the old two-argument calls working.
    image_shape : shape each flat feature row is reshaped to before drawing.

    Raises IndexError if fewer than ``n + 1`` samples carry label ``num``.
    """
    if x is None:
        x = x_train_in
    if y is None:
        y = y_train_in
    samples = np.asarray(x)[np.asarray(y) == num]
    digit_image = samples[n].reshape(image_shape)
    plt.imshow(digit_image, cmap=plt.cm.binary, interpolation='nearest')
    plt.axis('off')
    plt.show()
if __name__ == '__main__':
    # Eyeball one training sample: the 11th image labelled 8.
    to_plot(8, 10)

    # One-vs-rest linear SVM with hinge loss, C=5.
    ova_svm_clf = LinearSVC(C=5, loss='hinge', multi_class='ovr')
    ova_svm_clf.fit(x_train_in, y_train_in)

    # Out-of-fold predictions from 3-fold CV. NOTE(review): per the
    # scikit-learn docs, cross_val_predict fits clones of the estimator, so
    # these predictions do not come from the model fitted just above.
    y_prd = cross_val_predict(ova_svm_clf, x_train_in, y_train_in, cv=3)
    # Confusion matrix of those CV predictions against the true labels.
    conf_m = confusion_matrix(y_train_in, y_prd)
3、模型評估
### Overall accuracy
def clf_correct(y_train, y_prd):
    """Return the fraction of positions where ``y_prd`` equals ``y_train``.

    Works for any array-like labels (lists or NumPy arrays, numeric or
    string labels), because it compares with ``==`` rather than subtracting.

    Raises ValueError if the inputs are empty or have different shapes.
    """
    y_true = np.asarray(y_train)
    y_pred = np.asarray(y_prd)
    if y_true.shape != y_pred.shape:
        raise ValueError(
            "y_train and y_prd must have the same shape, got {} and {}".format(
                y_true.shape, y_pred.shape))
    if y_true.size == 0:
        raise ValueError("cannot compute accuracy of empty label arrays")
    # np.mean of a boolean array is the proportion of True entries.
    return np.mean(y_true == y_pred)
class plot_conf_m():
    """Draw a confusion matrix with ``plt.matshow``.

    conf_m : square count matrix. In this script it comes from
        ``confusion_matrix(y_true, y_pred)``, whose rows are true classes and
        whose columns are predicted classes (scikit-learn convention).
    """

    def __init__(self, conf_m):
        self.conf_m = conf_m

    def plt_conf_m(self):
        """Draw the raw confusion-matrix counts in grayscale."""
        plt.matshow(self.conf_m, cmap=plt.cm.gray)

    def plt_error_conf_m(self):
        """Draw row-normalized error rates with the diagonal blanked out.

        Each row is divided by its total, so cell (i, j) is the share of
        row i counted in column j. The diagonal (correct predictions) is set
        to 0 so the dominant confusions stand out. A row whose total is 0
        stays all zeros instead of turning into NaN.
        """
        conf_m = np.asarray(self.conf_m, dtype=float)
        row_sums = conf_m.sum(axis=1, keepdims=True)
        # Guarded division: rows with a zero total keep the zeros from `out`.
        norm_conf_m = np.divide(conf_m, row_sums,
                                out=np.zeros_like(conf_m),
                                where=row_sums != 0)
        # Overwrite correct classifications with 0 to see which classes
        # are confused most often.
        np.fill_diagonal(norm_conf_m, 0)
        plt.matshow(norm_conf_m, cmap=plt.cm.gray)
if __name__ == '__main__':
    print("整體準確性:{}".format(clf_correct(y_train_in, y_prd)))

    plt_confm = plot_conf_m(conf_m)
    # Raw counts: the bright diagonal is the correctly classified samples.
    plt_confm.plt_conf_m()
    plt.title("Focus on the correct prediction")
    # Row-normalized errors with the diagonal zeroed.
    plt_confm.plt_error_conf_m()
    plt.title("Focus on the error prediction")
    plt.show()
## Recorded overall accuracy: 0.902795918367347
從下面兩個混淆矩陣中可以看出,錯誤分類的分佈比較平均,準確率還有待提高,所以增大 C 進行重新擬合。
4、模型修正及預測
1. 模型修正
if __name__ == '__main__':
    # Same one-vs-rest hinge-loss LinearSVC, with C raised from 5 to 10.
    ova_svm_clf_fix = LinearSVC(C=10, loss='hinge', multi_class='ovr')
    ova_svm_clf_fix.fit(x_train_in, y_train_in)

    # 3-fold out-of-fold predictions and their confusion matrix.
    y_prd_fix = cross_val_predict(ova_svm_clf_fix, x_train_in, y_train_in, cv=3)
    conf_m_fix = confusion_matrix(y_train_in, y_prd_fix)
    print("整體準確性:{}".format(clf_correct(y_train_in, y_prd_fix)))

    plt_confm_fix = plot_conf_m(conf_m_fix)
    plt_confm_fix.plt_conf_m()
    plt.title("Focus on the correct prediction")
    plt_confm_fix.plt_error_conf_m()
    plt.title("Focus on the error prediction")
    plt.show()
## Recorded overall accuracy: 0.91204
增大 C 雖然略微提高了整體準確率(0.903 → 0.912),但錯誤分類的情況並沒有明顯好轉,可見線性核對該資料的分類效果有限,所以改用高斯核(RBF)進行擬合。
from sklearn.svm import SVC
from sklearn.metrics import classification_report
if __name__ == '__main__':
    # RBF-kernel SVC with a one-vs-rest decision function shape.
    ova_svm_clf_rbf = SVC(kernel='rbf', gamma='auto', C=15,
                          cache_size=8000, decision_function_shape='ovr')
    ova_svm_clf_rbf.fit(x_train_in, y_train_in)

    # NOTE: the model is scored on the same samples it was fitted on, so
    # this is training accuracy (recorded as ~0.983 below), not a
    # cross-validated estimate like the LinearSVC numbers earlier.
    y_prd_rbf = ova_svm_clf_rbf.predict(x_train_in)
    print('整體準確率{}'.format(clf_correct(y_train_in, y_prd_rbf)))

    conf_m_rbf = confusion_matrix(y_train_in, y_prd_rbf)
    plt_confm_rbf = plot_conf_m(conf_m_rbf)
    plt_confm_rbf.plt_conf_m()
    plt.title("Focus on the correct prediction")
    plt_confm_rbf.plt_error_conf_m()
    plt.title("Focus on the error prediction")
    plt.show()

    # Per-class precision / recall / F1 report.
    print(classification_report(y_train_in, y_prd_rbf))

    # Recorded output (training set):
    # Overall accuracy: 0.9831632653061224
    #              precision  recall  f1-score  support
    #         0.0       0.99    0.99      0.99     4832
    #         1.0       0.99    0.99      0.99     5514
    #         2.0       0.98    0.99      0.99     4893
    #         3.0       0.98    0.97      0.97     4999
    #         4.0       0.98    0.98      0.98     4777
    #         5.0       0.98    0.98      0.98     4419
    #         6.0       0.99    0.99      0.99     4813
    #         7.0       0.98    0.98      0.98     5105
    #         8.0       0.98    0.98      0.98     4777
    #         9.0       0.98    0.97      0.97     4871
    # avg / total       0.98    0.98      0.98    49000
高斯核在訓練集上的準確率明顯提升了,但對 9 和 4、3 和 5 的識別還不是十分精確。
2. 模型預測
if __name__ == '__main__':
    # Final evaluation on the held-out test split with the RBF model.
    # NOTE(review): the recorded test accuracy (~0.96) is well above the
    # linear models' ~0.91 CV accuracy, so it was presumably produced by the
    # RBF model that the text selects; confirm if re-running.
    y_test_prd = ova_svm_clf_rbf.predict(x_test_in)
    print("整體準確性:{}".format(clf_correct(y_test_in, y_test_prd)))

    # Confusion matrix of the *test* predictions.
    conf_m_test = confusion_matrix(y_test_in, y_test_prd)
    plt_confm_test = plot_conf_m(conf_m_test)
    plt_confm_test.plt_conf_m()
    plt.title("Focus on the correct prediction")
    plt_confm_test.plt_error_conf_m()
    plt.title("Focus on the error prediction")
    plt.show()

    # Per-class precision / recall / F1 on the test split.
    print(classification_report(y_test_in, y_test_prd))

    # Recorded output (test set):
    # Overall accuracy: 0.9615238095238096
    #              precision  recall  f1-score  support
    #         0.0       0.97    0.99      0.98     2071
    #         1.0       0.97    0.98      0.98     2363
    #         2.0       0.96    0.97      0.96     2097
    #         3.0       0.95    0.95      0.95     2142
    #         4.0       0.96    0.96      0.96     2047
    #         5.0       0.96    0.94      0.95     1894
    #         6.0       0.97    0.98      0.97     2063
    #         7.0       0.97    0.96      0.97     2188
    #         8.0       0.95    0.95      0.95     2048
    #         9.0       0.94    0.94      0.94     2087
    # avg / total       0.96    0.96      0.96    21000