1. 程式人生 > >KNN 演算法-實戰篇-如何識別手寫數字

KNN 演算法-實戰篇-如何識別手寫數字

> **公號:碼農充電站pro** > **主頁:** 上篇文章介紹了[KNN 演算法的原理](https://www.cnblogs.com/codeshell/p/14072586.html),今天來介紹如何**使用KNN 演算法識別手寫數字**? ### 1,手寫數字資料集 手寫數字資料集是一個用於影象處理的資料集,這些資料描繪了 **[0, 9]** 的數字,我們可以用**KNN 演算法**來識別這些數字。 [MNIST](http://yann.lecun.com/exdb/mnist/) 是完整的手寫數字資料集,其中包含了60000 個訓練樣本和10000 個測試樣本。 **sklearn** 中也有一個自帶的[手寫數字資料集](https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/datasets/data/digits.csv.gz): - 共包含 1797 個數據樣本,每個樣本描繪了一個 **8*8** 畫素的 **[0, 9]** 的數字。 - 每個樣本由 65 個數字組成: - 前 64 個數字是特徵資料,特徵資料的範圍是 **[0, 16]** - 最後一個數字是目標資料,目標資料的範圍是 **[0, 9]** 我們抽出 5 個樣本來看下: ```python 0,0,5,13,9,1,0,0,0,0,13,15,10,15,5,0,0,3,15,2,0,11,8,0,0,4,12,0,0,8,8,0,0,5,8,0,0,9,8,0,0,4,11,0,1,12,7,0,0,2,14,5,10,12,0,0,0,0,6,13,10,0,0,0,0 0,0,0,12,13,5,0,0,0,0,0,11,16,9,0,0,0,0,3,15,16,6,0,0,0,7,15,16,16,2,0,0,0,0,1,16,16,3,0,0,0,0,1,16,16,6,0,0,0,0,1,16,16,6,0,0,0,0,0,11,16,10,0,0,1 0,0,0,4,15,12,0,0,0,0,3,16,15,14,0,0,0,0,8,13,8,16,0,0,0,0,1,6,15,11,0,0,0,1,8,13,15,1,0,0,0,9,16,16,5,0,0,0,0,3,13,16,16,11,5,0,0,0,0,3,11,16,9,0,2 0,0,7,15,13,1,0,0,0,8,13,6,15,4,0,0,0,2,1,13,13,0,0,0,0,0,2,15,11,1,0,0,0,0,0,1,12,12,1,0,0,0,0,0,1,10,8,0,0,0,8,4,5,14,9,0,0,0,7,13,13,9,0,0,3 0,0,0,1,11,0,0,0,0,0,0,7,8,0,0,0,0,0,1,13,6,2,2,0,0,0,7,15,0,9,8,0,0,5,16,10,0,16,6,0,0,4,15,16,13,16,1,0,0,0,0,3,15,10,0,0,0,0,0,2,16,4,0,0,4 ``` 使用該資料集,需要先載入: ```python >
>> from sklearn.datasets import load_digits >>> digits = load_digits() ``` 檢視第一個影象資料: ```python >>> digits.images[0] array([[ 0., 0., 5., 13., 9., 1., 0., 0.], [ 0., 0., 13., 15., 10., 15., 5., 0.], [ 0., 3., 15., 2., 0., 11., 8., 0.], [ 0., 4., 12., 0., 0., 8., 8., 0.], [ 0., 5., 8., 0., 0., 9., 8., 0.], [ 0., 4., 11., 0., 1., 12., 7., 0.], [ 0., 2., 14., 5., 10., 12., 0., 0.], [ 0., 0., 6., 13., 10., 0., 0., 0.]]) ``` 我們可以用 [matplotlib](https://matplotlib.org/) 將該影象畫出來: ```python >
>> import matplotlib.pyplot as plt >>> plt.imshow(digits.images[0]) >>> plt.show() ``` 畫出來的影象如下,代表 **0**: ![在這裡插入圖片描述](https://img-blog.csdnimg.cn/20201130132714324.png?) ### 2,sklearn 對 KNN 演算法的實現 **sklearn** 庫的 [neighbors](https://scikit-learn.org/stable/modules/classes.html#module-sklearn.neighbors) 模組實現了**KNN** 相關演算法,其中: - `KNeighborsClassifier` 類用於分類問題 - `KNeighborsRegressor` 類用於迴歸問題 這兩個類的構造方法基本一致,這裡我們主要介紹 `KNeighborsClassifier` 類,原型如下: ```python KNeighborsClassifier( n_neighbors=5, weights='uniform', algorithm='auto', leaf_size=30, p=2, metric='minkowski', metric_params=None, n_jobs=None, **kwargs) ``` 來看下幾個重要引數的含義: - **n_neighbors**:即 **KNN** 中的 K 值,一般使用預設值 5。 - **weights**:用於確定鄰居的權重,有三種方式: - weights=uniform,表示所有鄰居的權重相同。 - weights=distance,表示權重是距離的倒數,即與距離成反比。 - 自定義函式,可以自定義不同距離所對應的權重,一般不需要自己定義函式。 - **algorithm**:用於設定計算鄰居的演算法,它有四種方式: - algorithm=auto,根據資料的情況自動選擇適合的演算法。 - algorithm=kd_tree,使用 [KD 樹](https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KDTree.html) 演算法。 - **KD 樹**是一種多維空間的資料結構,方便對資料進行檢索。 - **KD 樹**適用於維度較少的情況,一般維數不超過 20,如果維數大於 20 之後,效率會下降。 - algorithm=ball_tree,使用[球樹](https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.BallTree.html)演算法。 - 與**KD 樹**一樣都是多維空間的資料結構。 - **球樹**更適用於維度較大的情況。 - algorithm=brute,稱為**暴力搜尋**。 - 它和 **KD 樹**相比,採用的是線性掃描,而不是通過構造樹結構進行快速檢索。 - 缺點是,當訓練集較大的時候,效率很低。 - **leaf_size**:表示構造 **KD 樹**或**球樹**時的**葉子節點數**,預設是 30。 - 調整 leaf_size 會影響樹的構造和搜尋速度。 ### 3,構造 KNN 分類器 首先載入資料集: ```python from sklearn.datasets import load_digits digits = load_digits() data = digits.data # 特徵集 target = digits.target # 目標集 ``` 將資料集拆分為**訓練集**(75%)和**測試集**(25%): ```python from sklearn.model_selection import train_test_split train_x, test_x, train_y, test_y = train_test_split( data, target, test_size=0.25, random_state=33) ``` 構造**KNN** 分類器: ```python from sklearn.neighbors import KNeighborsClassifier # 採用預設引數 knn = KNeighborsClassifier() ``` 擬合模型: ```python knn.fit(train_x, train_y) ``` 預測資料: ```python predict_y = knn.predict(test_x) ``` 計算模型準確度: ```python from sklearn.metrics import accuracy_score score = accuracy_score(test_y, predict_y) print score # 0.98 ``` 最終計算出來模型的準確度是 **98%**,準確度還是不錯的。 ### 4,總結 本篇文章使用**KNN 演算法**處理了一個實際的分類問題,主要介紹了以下幾點: - 介紹了**sklearn** 中自帶的手寫數字集,並用 **matplotlib** 模組畫出了數字影象。 - 介紹了**sklearn** 中 `neighbors.KNeighborsClassifier` 類的用法。 - 使用 `KNeighborsClassifier` 來識別手寫數字。 (本節完。) --- **推薦閱讀:** [***KNN 演算法-理論篇-如何給電影進行分類***](https://www.cnblogs.com/codeshell/p/14072586.html) [***決策樹演算法-理論篇-如何計算資訊純度***](https://www.cnblogs.com/codeshell/p/13984334.html) [***決策樹演算法-實戰篇-鳶尾花及波士頓房價預測***](https://www.cnblogs.com/codeshell/p/13984334.html) [***樸素貝葉斯分類-理論篇-如何通過概率解決分類問題***](https://www.cnblogs.com/codeshell/p/13999440.html) [***樸素貝葉斯分類-實戰篇-如何進行文字分類***](https://www.cnblogs.com/codeshell/p/14034097.html) --- 歡迎關注作者公眾號,獲取更多技術乾貨。 ![碼農充電站pro](https://img-blog.csdnimg.cn/20200505082843773.png?#pic