KNN實現手寫數字的識別
阿新 • 發佈:2018-12-21
# Recognise handwritten digits (0-9) with a k-nearest-neighbours classifier:
# load the bitmap dataset, train, persist/reload the model, and report scores.
import numpy as np
import matplotlib.pyplot as plt
# sklearn.externals.joblib was deprecated in scikit-learn 0.21 and removed in
# 0.23; import the standalone joblib package (a scikit-learn dependency) instead.
import joblib
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split

# Dataset layout: ./knn_num_data/<digit>/<digit>_<index>.bmp, index starting at 1.
# To eyeball a sample:
#     plt.imshow(plt.imread('./knn_num_data/0/0_1.bmp'), cmap='gray')
path = './knn_num_data/%d/%d_%d.bmp'
n_classes = 10      # digits 0-9, one sub-directory each
n_per_class = 500   # images loaded per digit

# Assemble the images and their labels (the label is the directory's digit).
data = []
target = []
for i in range(n_classes):
    for j in range(n_per_class):
        im_data = plt.imread(path % (i, i, j + 1))
        data.append(im_data)
        target.append(i)
data = np.array(data)

# KNeighborsClassifier expects a 2-D (n_samples, n_features) matrix, so flatten
# each image. len(data) replaces the former hard-coded 5000 so the counts above
# can change without breaking the reshape.
data_ = data.reshape(len(data), -1)

# Hold out 1% of the samples as the test set.
X_train, X_test, y_train, y_test = train_test_split(data_, target, test_size=0.01)

# Instantiate the KNN classifier with default hyper-parameters and train it.
knn = KNeighborsClassifier()
knn.fit(X_train, y_train)

# Persist the trained model, then reload it to demonstrate the round trip.
# NOTE: joblib.load unpickles the file -- only load model files you trust.
save_path_name = 'knn_train_model.m'
joblib.dump(knn, save_path_name)
knn = joblib.load(save_path_name)

# Predicted labels for the held-out samples.
y_ = knn.predict(X_test)
print(y_)

# Mean accuracy on the training set.
train_score = knn.score(X_train, y_train)
print(train_score)

# Mean accuracy on the test set.
test_score = knn.score(X_test, y_test)
print(test_score)
資料集
連結:https://pan.baidu.com/s/1ehaljfupk-_kuxk3khh3BA 提取碼:zl3o