
Machine Learning from Scratch: the kNN Algorithm (Part 4)

This experiment uses a different dataset: the handwritten digits dataset.

First, import the necessary libraries:

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from sklearn import datasets

digits = datasets.load_digits()

digits.keys()
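On a recent scikit-learn release this prints something like the following (the exact set of fields varies by version, so treat it as illustrative):

dict_keys(['data', 'target', 'frame', 'feature_names', 'target_names', 'images', 'DESCR'])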

print(digits.DESCR)  # take a look at the dataset description

.. _digits_dataset:

Optical recognition of handwritten digits dataset
--------------------------------------------------

**Data Set Characteristics:**

    :Number of Instances: 5620
    :Number of Attributes: 64
    :Attribute Information: 8x8 image of integer pixels in the range 0..16.
    :Missing Attribute Values: None
    :Creator: E. Alpaydin (alpaydin '@' boun.edu.tr)
    :Date: July; 1998

This is a copy of the test set of the UCI ML hand-written digits datasets
http://archive.ics.uci.edu/ml/datasets/Optical+Recognition+of+Handwritten+Digits
The data set contains images of hand-written digits: 10 classes where
each class refers to a digit.

Preprocessing programs made available by NIST were used to extract
normalized bitmaps of handwritten digits from a preprinted form. From a
total of 43 people, 30 contributed to the training set and different 13
to the test set. 32x32 bitmaps are divided into nonoverlapping blocks of
4x4 and the number of on pixels are counted in each block. This generates
an input matrix of 8x8 where each element is an integer in the range
0..16. This reduces dimensionality and gives invariance to small
distortions.

For info on NIST preprocessing routines, see M. D. Garris, J. L. Blue,
G. T. Candela, D. L. Dimmick, J. Geist, P. J. Grother, S. A. Janet, and
C. L. Wilson, NIST Form-Based Handprint Recognition System, NISTIR 5469,
1994.

.. topic:: References

  - C. Kaynak (1995) Methods of Combining Multiple Classifiers and Their
    Applications to Handwritten Digit Recognition, MSc Thesis, Institute
    of Graduate Studies in Science and Engineering, Bogazici University.
  - E. Alpaydin, C. Kaynak (1998) Cascading Classifiers, Kybernetika.
  - Ken Tang and Ponnuthurai N. Suganthan and Xi Yao and A. Kai Qin.
    Linear dimensionality reduction using relevance weighted LDA. School
    of Electrical and Electronic Engineering Nanyang Technological
    University. 2005.
  - Claudio Gentile. A New Approximate Maximal Margin Classification
    Algorithm. NIPS. 2000.

X = digits.data  # the feature matrix
y = digits.target  # the labels, needed for train_test_split below
X.shape  # sklearn ships a reduced copy: 1797 samples rather than 5620, each with 64 attributes (a flattened 8x8 matrix)

Let's look at the attributes of the first 100 samples:
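The original post shows the raw array output here; slicing the feature matrix reproduces that view (a minimal sketch, where the choice of 100 rows simply matches the text above):

X[:100]  # 100 rows, each a flattened 8x8 image whose entries are integers in 0..16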

You can see that this dataset is distributed quite differently from the iris dataset: these raw pixel values show no obvious pattern.

Pick an arbitrary sample and take a look:
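The rendered image from the original post is not included; here is a minimal sketch to draw one sample yourself (the index 666 is an arbitrary pick, any row works):

some_digit = X[666]  # take one flattened sample
some_digit_image = some_digit.reshape(8, 8)  # restore the 8x8 pixel layout
plt.imshow(some_digit_image, cmap='binary')  # darker cells mean more "on" pixels
plt.show()
y[666]  # compare the stored label with what the image looks like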

The sample is basically recognizable as the digit 8.

Next, let's test with the packaged kNN algorithm:

from sklearn.model_selection import train_test_split  # helper that splits the dataset

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)  # hold out 20% of the data as the test set

from sklearn.neighbors import KNeighborsClassifier  # the packaged kNN algorithm

my_knn_clf = KNeighborsClassifier(n_neighbors=3)  # k = 3

my_knn_clf.fit(X_train, y_train)  # fit on the training set

y_predict = my_knn_clf.predict(X_test)  # predict labels for the test samples

y_predict

# y_predict == y_test compares the two vectors element-wise and yields True wherever they match;
# sum() counts the True values, so dividing by the number of test samples gives the prediction accuracy
sum(y_predict == y_test) / len(y_test)

If you would rather not write this logic yourself, sklearn provides it directly:

from sklearn.metrics import accuracy_score

accuracy_score(y_test,y_predict)

Or skip y_predict altogether: the classifier's score method combines prediction and accuracy in a single call. Note that the model must be fitted on the training set, never on the test set:

my_knn_clf.fit(X_train, y_train)  # fit on the training data (fitting on the test data would make the score meaningless)

my_knn_clf.score(X_test, y_test)  # predicts on X_test internally and returns the accuracy
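One practical caveat: train_test_split shuffles at random, so the scores above drift from run to run. Passing a fixed random_state (the value 666 below is an arbitrary choice) makes the split, and therefore the accuracy, reproducible:

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=666)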