1. 程式人生 > 實用技巧 >機器學習——用卷積神經網路(CNN)實現手寫數字識別

機器學習——用卷積神經網路(CNN)實現手寫數字識別

原文連結:https://data-flair.training/blogs/python-deep-learning-project-handwritten-digit-recognition/

原文講得很詳細,這裡補充一些註釋。由於直接從庫匯入mnist資料集需要的時間非常久,因此這裡匯入的是本地已下載好的mnist資料集。(但我懷疑我下了假的資料集,咋驗證準確率這麼低,所以這裡不提供了)

import keras
from keras import backend as K
import numpy as np
from keras.datasets import mnist
from keras.models import
Sequential from keras.layers import Dense, Dropout, Flatten from keras.layers import Conv2D, MaxPooling2D batch_size = 128 #一次訓練所選取的樣本數 num_classes = 10 #分類個數 epochs = 10 #訓練輪數 #讀取已下載到本地的資料集 f=np.load('C:/Users/Administrator/.keras/datasets/mnist.npz') x_train,y_train=f['x_train'],f['y_train
'] x_test,y_test=f['x_test'],f['y_test'] #print(x_train.shape, y_train.shape) #資料預處理 x_train = x_train.reshape(x_train.shape[0], 28, 28, 1) x_test = x_test.reshape(x_test.shape[0], 28, 28, 1) input_shape = (28, 28, 1) x_train = x_train.astype('float32') #轉換資料型別 x_test = x_test.astype('float32') x_train
/= 255 #歸一化 x_test /= 255 y_train = keras.utils.to_categorical(y_train, num_classes) #將整形陣列轉化為二元型別矩陣 y_test = keras.utils.to_categorical(y_test, num_classes) #print('x_train shape:', x_train.shape) #print(x_train.shape[0], 'train samples') #print(x_test.shape[0], 'test samples') #建立CNN模型 model = Sequential() #這裡採用順序模型構建CNN #輸入層,這裡指定輸入資料形狀為28*28*1 卷積核數量為32 形狀為3*3 model.add(Conv2D(32, kernel_size=(3, 3),activation='relu',input_shape=input_shape)) #新增中間層 model.add(Conv2D(64, (3, 3), activation='relu')) #卷積層 model.add(MaxPooling2D(pool_size=(2, 2))) #最大池化層 model.add(Dropout(0.25)) #通過Dropout防止過擬合 model.add(Flatten()) #展平層 model.add(Dense(256, activation='relu')) #全連線層 model.add(Dropout(0.5)) model.add(Dense(num_classes, activation='softmax')) #損失函式 model.compile(loss=keras.losses.categorical_crossentropy,optimizer=keras.optimizers.Adadelta(),metrics=['accuracy']) #訓練模型 hist = model.fit(x_train, y_train,batch_size=batch_size,epochs=epochs,verbose=2,validation_data=(x_test, y_test)) print("模型訓練完成") #模型評估 score = model.evaluate(x_test, y_test, verbose=0) print('test loss: ', score[0]) print('test accuracy: ', score[1])