scikit-learn 支援向量機實現手寫體識別
阿新 • • 發佈:2018-12-09
隨書程式碼,閱讀筆記
# Handwritten digit recognition with a scikit-learn support vector machine.
#
# This was originally a Jupyter notebook. When running it there, execute the
# IPython magic `%matplotlib inline` first so figures render inline. The magic
# is not valid Python syntax, so it is kept only as this comment; plt.show()
# at the end displays the figures when the file is run as a plain script.
import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets, svm
from sklearn.metrics import accuracy_score

try:
    from sklearn.model_selection import train_test_split
except ImportError:  # scikit-learn < 0.18; cross_validation was removed in 0.20
    from sklearn.cross_validation import train_test_split

try:
    import joblib  # standalone package; sklearn.externals.joblib was removed in 0.23
except ImportError:  # old scikit-learn that still ships the vendored copy
    from sklearn.externals import joblib

# Load the bundled digits dataset.
digits = datasets.load_digits()

# Show the first eight images together with their labels.
images_and_labels = list(zip(digits.images, digits.target))
plt.figure(figsize=(8, 6), dpi=200)
for index, (image, label) in enumerate(images_and_labels[:8]):
    plt.subplot(2, 4, index + 1)
    plt.axis('off')
    plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
    plt.title('Digit: %i' % label, fontsize=20)

print("shape of raw image data: {0}".format(digits.images.shape))
print("shape of data: {0}".format(digits.data.shape))
# Output recorded by the original author:
#   shape of raw image data: (1797, 8, 8)
#   shape of data: (1797, 64)
# i.e. each 8x8 image is flattened into a 64-feature row in digits.data.

# Hold out 20% of the samples as a test set; a fixed random_state makes the
# split reproducible across runs.
Xtrain, Xtest, Ytrain, Ytest = train_test_split(
    digits.data, digits.target, test_size=0.20, random_state=2)

# Train an SVC (RBF kernel by default). probability=True is required for
# predict_proba below; per the scikit-learn docs it fits an extra internal
# cross-validated calibration, so training is slower.
clf = svm.SVC(gamma=0.001, C=100., probability=True)
clf.fit(Xtrain, Ytrain)

# Evaluate accuracy on the held-out set. Both calls compute the same mean
# accuracy. They are printed because bare expressions are only echoed by a
# notebook when they are the last line of a cell.
Ypred = clf.predict(Xtest)
print("accuracy_score: {0}".format(accuracy_score(Ytest, Ypred)))
print("clf.score: {0}".format(clf.score(Xtest, Ytest)))

# Inspect predictions on the first 16 test images. Each tile shows the
# predicted label bottom-left (green if correct, red if wrong) and the true
# label bottom-right in black.
fig, axes = plt.subplots(4, 4, figsize=(8, 8))
fig.subplots_adjust(hspace=0.1, wspace=0.1)
for i, ax in enumerate(axes.flat):
    ax.imshow(Xtest[i].reshape(8, 8), cmap=plt.cm.gray_r,
              interpolation='nearest')
    ax.text(0.05, 0.05, str(Ypred[i]), fontsize=32, transform=ax.transAxes,
            color='green' if Ypred[i] == Ytest[i] else 'red')
    ax.text(0.8, 0.05, str(Ytest[i]), fontsize=32, transform=ax.transAxes,
            color='black')
    ax.set_xticks([])
    ax.set_yticks([])

# Per-class probabilities for the test sample Xtest[4]. reshape(1, -1) turns
# the single feature vector into a one-row 2-D array, as the estimator expects.
print(clf.predict_proba(Xtest[4].reshape(1, -1)))

# Persist the trained model, then reload it and predict with the reloaded
# copy.
# NOTE: joblib.load unpickles the file; only load model files you trust.
joblib.dump(clf, 'digits_svm.pkl')
clf = joblib.load('digits_svm.pkl')
Ypred = clf.predict(Xtest)
print("reloaded clf.score: {0}".format(clf.score(Xtest, Ytest)))

# Display the figures when running as a script (harmless in a notebook).
plt.show()
本例的影像只有 8×8 畫素,直接把畫素值當作特徵,分類效果還不錯;但如果影像尺寸很大,直接使用畫素值作為特徵,分類效果往往不好,這時需要先做降維處理,例如結合 PCA。