使用cnn網路訓練手寫數字資料集
阿新 • 發佈：2020-12-24
使用cnn網路訓練手寫數字資料集
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2020/12/23 20:27
# @Author : yang
# @Site :
# @File : mnist2.py
# @Software: PyCharm
# 試了一下用cnn網路直接訓練手寫數字資料集,即不需要將二維圖片一維化,其他的貌似都沒變
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
import tensorflow.keras as keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D , MaxPool2D , Flatten , Dropout , BatchNormalization
num_categories = 10
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# One-hot encode the integer labels, because the model is trained with
# categorical_crossentropy. One-hot vectors are only practical when the number
# of classes is small, as it is here (num_categories = 10).
y_train = keras.utils.to_categorical(y_train, num_categories)
y_test = keras.utils.to_categorical(y_test, num_categories)
# Normalize pixel values by dividing by 255. The raw arrays are integer pixel
# data (uint8 per the Keras load_data docs). Cast to float32 first: dividing
# the integer arrays directly yields float64 arrays twice the size, and Keras'
# default float type is float32, so the extra precision would be discarded.
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255
# The images are already 2-D 28x28 arrays; nothing needs to be "un-flattened".
# Conv2D, however, expects an explicit channel axis, so append a single
# (grayscale) channel: (N, 28, 28) -> (N, 28, 28, 1).
x_train = x_train.reshape(-1, 28, 28, 1)
x_test = x_test.reshape(-1, 28, 28, 1)
#------------------------------- CNN construction -------------------------------
# Three convolution stages. Each uses a 3x3 kernel with 'same' padding and ReLU,
# followed by batch normalization and 2x2 max pooling with stride 2. The second
# stage additionally applies 20% dropout before its batch normalization. A
# 512-unit ReLU dense layer with 30% dropout feeds the num_categories-way
# softmax output.
model = Sequential([
    # Stage 1 -- the first conv layer also declares the (28, 28, 1) input shape.
    Conv2D(75, (3, 3), strides=1, padding='same', activation='relu',
           input_shape=(28, 28, 1)),
    BatchNormalization(),
    MaxPool2D((2, 2), strides=2, padding='same'),
    # Stage 2
    Conv2D(50, (3, 3), strides=1, padding='same', activation='relu'),
    Dropout(0.2),
    BatchNormalization(),
    MaxPool2D((2, 2), strides=2, padding='same'),
    # Stage 3
    Conv2D(25, (3, 3), strides=1, padding='same', activation='relu'),
    BatchNormalization(),
    MaxPool2D((2, 2), strides=2, padding='same'),
    # Classifier head: flatten the feature maps into fully connected layers.
    Flatten(),
    Dense(units=512, activation='relu'),
    Dropout(0.3),
    Dense(units=num_categories, activation='softmax'),
])
# model.summary()  # uncomment to print the per-layer summary
#------------------------------- Compile -------------------------------
# Keras' Model.compile defaults to optimizer='rmsprop'. Name it explicitly so
# the training setup is visible here and unaffected if that default changes.
# categorical_crossentropy matches the one-hot encoded labels.
model.compile(optimizer='rmsprop',
              loss='categorical_crossentropy',
              metrics=['accuracy'])
#------------------------------- Train -------------------------------
# Train for 5 epochs with a per-epoch progress display (verbose=1), reporting
# loss/accuracy on validation_data after each epoch.
# NOTE(review): the test split doubles as validation data, so the reported
# val_* metrics are test-set metrics. Use a separate validation split if those
# numbers are ever used to choose hyperparameters or stop training.
model.fit(x_train, y_train,
          epochs=5,
          verbose=1,
          validation_data=(x_test, y_test))
此 CNN 訓練五輪（epochs=5）後，效果似乎比之前的三層網路要好。