1. 程式人生 > 其它 >使用cnn網路訓練手寫數字資料集

使用cnn網路訓練手寫數字資料集

使用cnn網路訓練手寫數字資料集

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2020/12/23 20:27
# @Author : yang
# @Site : 
# @File : mnist2.py
# @Software: PyCharm
# 試了一下用cnn網路直接訓練手寫數字資料集,即不需要將二維圖片一維化,其他的貌似都沒變

import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
import tensorflow.keras as
keras from tensorflow.keras.models import Sequential from tensorflow.keras.layers import Dense, Conv2D , MaxPool2D , Flatten , Dropout , BatchNormalization num_categories = 10 (x_train, y_train),(x_test,y_test) = mnist.load_data() y_train = keras.utils.to_categorical(y_train, num_categories) #將標籤進行one-hot編碼,僅僅適用於分類較小
y_test = keras.utils.to_categorical(y_test, num_categories) x_train = x_train / 255 #將資料進行歸一化處理,提高效率 x_test = x_test / 255 x_train = x_train.reshape(-1,28,28,1) #由於cnn是一個使用二維資料的模型,因此將一維資料還原為二維。 x_test = x_test.reshape(-1,28,28,1
) #-------------------------------cnn網路構建------------------------------- #該cnn網路採用的是3*3的小卷積核,填充模式是same填充,啟用函式是relu model = Sequential() model.add(Conv2D(75 , (3,3) , strides = 1 , padding = 'same' , activation = 'relu' , input_shape = (28,28,1))) #輸入層 model.add(BatchNormalization()) model.add(MaxPool2D((2,2) , strides = 2 , padding = 'same')) #以下為卷積層 model.add(Conv2D(50 , (3,3) , strides = 1 , padding = 'same' , activation = 'relu')) model.add(Dropout(0.2)) model.add(BatchNormalization()) model.add(MaxPool2D((2,2) , strides = 2 , padding = 'same')) model.add(Conv2D(25 , (3,3) , strides = 1 , padding = 'same' , activation = 'relu')) model.add(BatchNormalization()) model.add(MaxPool2D((2,2) , strides = 2 , padding = 'same')) model.add(Flatten()) #全連線層 model.add(Dense(units = 512 , activation = 'relu')) model.add(Dropout(0.3)) model.add(Dense(units = num_categories , activation = 'softmax')) # model.summary() #-------------------------------模型編譯------------------------------- model.compile(loss = 'categorical_crossentropy' , metrics = ['accuracy']) #-------------------------------訓練------------------------------- model.fit(x_train, y_train, epochs=5, verbose=1, validation_data=(x_test, y_test))

訓練五輪後貌似比三層的效果要好

在這裡插入圖片描述