eclipse擼一發Keras卷積神經網路對手寫數字識別
一、導讀
1、window10 python環境Anaconda 安裝
2、keras安裝
3、tensorflow安裝
4、eclipse python開發外掛PyDev安裝,配置
5、keras卷積神經網路對手寫數字識別
二、環境安裝
1、Anaconda
Anaconda指的是一個開源的Python發行版本,其包含了conda、Python等180多個科學包及其依賴項。功能非常齊全。
下載地址:https://www.anaconda.com/distribution/#download-section
下載安裝
安裝成功後,配置環境變數,開啟cmd,用命令python --version檢視版本號
其中,環境變數非常重要,ssl的環境變數必須配置,否則使用pip時會出現異常(pip is configured with locations that require TLS/SSL, however the ssl module in Python is not available. )
特別說明:安裝之後,可能有些包是有問題的,例如:matplotlib和Pillow沒法正常使用,會出現如下異常
File "E:\Program Files\Anaconda3\lib\site-packages\matplotlib\pyplot.py", line 32, in <module> import matplotlib.colorbar File "E:\Program Files\Anaconda3\lib\site-packages\matplotlib\colorbar.py", line 32, in <module> import matplotlib.contour as contour File "E:\Program Files\Anaconda3\lib\site-packages\matplotlib\contour.py", line 18, in <module> import matplotlib.font_manager as font_manager File "E:\Program Files\Anaconda3\lib\site-packages\matplotlib\font_manager.py", line 48, in <module> from matplotlib import afm, cbook, ft2font, rcParams, get_cachedir ImportError: DLL load failed: 找不到指定的模組。
這時候,跟到異常的程式碼行,看哪些庫是沒有正常依賴進來的,重新安裝即可。可以類推如下的操作,來進行
pip uninstall matplotlib
pip install matplotlib
pip uninstall Pillow
pip install Pillow
2、keras安裝
使用如下命令安裝:
pip install keras
3、安裝tensorflow
由於keras本身不提供執行,它依賴於其他的運算引擎,例如TensorFlow、Theano、CNTK等,這裡,我們安裝TensorFlow作為keras的backend。
用如下命令安裝:
Tensorflow有cpu版和gpu版
cpu版:pip install --user --ignore-installed --upgrade tensorflow
gpu版:pip install --user --ignore-installed --upgrade tensorflow-gpu
這裡,安裝的是cpu版,--user引數是因為有許可權問題,所以加上
4、eclipse PyDev外掛安裝
(1)、開啟eclipse,Help->Eclipse Marketplace,搜尋選擇PyDev,安裝
(2)、配置eclipse開發環境
window->preferences->PyDev->interpreters->Python Interpreter
配置完成點確定即可,便可以在eclipse上開發python程式了,和開發java一樣,非常方便。
三、Keras卷積神經網路進行mnist手寫數字識別
1、先載入資料集
(x_train, y_train), (x_test, y_test) = mnist.load_data()
說明: mnist.load_data()如果本地沒有,在會去下載資料集
列印一下x_train的shape,看看資料集有哪些東西
print(x_train.shape)
2、展示樣本的圖片
#將畫板分為1行2列,本幅圖位於第一個位置
plt.subplot(1,2,1)
# 列印兩個個樣本
plt.imshow(x_train[5], cmap='gray')
plt.subplot(1,2,2)
plt.imshow(x_train[10], cmap='gray')
plt.show()
3、疊一個CNN網路
model = Sequential()
model.add(Conv2D(32, kernel_size=(5, 5), activation='relu', input_shape=(img_x, img_y, 1)))
model.add(MaxPool2D(pool_size=(2, 2), strides=(2, 2)))
model.add(Conv2D(64, kernel_size=(5, 5), activation='relu'))
model.add(MaxPool2D(pool_size=(2, 2), strides=(2, 2)))
model.add(Flatten())
model.add(Dense(100, activation='relu'))
model.add(Dense(10, activation='softmax'))
model.compile(optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy'])
和之前用dl4j疊CNN一樣,《有趣的卷積神經網路》
4、完整的程式碼
from keras.datasets import mnist
from keras.layers import Conv2D, MaxPool2D
from keras.layers import Dense, Flatten
from keras.models import Sequential
from keras.utils import to_categorical
import matplotlib.pyplot as plt
(x_train, y_train), (x_test, y_test) = mnist.load_data()
print(x_train.shape)
#將畫板分為1行2列,本幅圖位於第一個位置
plt.subplot(1,2,1)
# 列印兩個個樣本
plt.imshow(x_train[5], cmap='gray')
plt.subplot(1,2,2)
plt.imshow(x_train[10], cmap='gray')
plt.show()
img_x, img_y = 28, 28
x_train = x_train.reshape(x_train.shape[0], img_x, img_y, 1)
x_test = x_test.reshape(x_test.shape[0], img_x, img_y, 1)
x_train = x_train.astype('float32')/255
x_test = x_test.astype('float32')/255
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)
model = Sequential()
model.add(Conv2D(32, kernel_size=(5, 5), activation='relu', input_shape=(img_x, img_y, 1)))
model.add(MaxPool2D(pool_size=(2, 2), strides=(2, 2)))
model.add(Conv2D(64, kernel_size=(5, 5), activation='relu'))
model.add(MaxPool2D(pool_size=(2, 2), strides=(2, 2)))
model.add(Flatten())
model.add(Dense(100, activation='relu'))
model.add(Dense(10, activation='softmax'))
model.compile(optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=128, epochs=1)
score = model.evaluate(x_test, y_test)
print('accuracy', score[1])
5、執行結果
eclipse run as Python Run,迭代1個批次的結果如下。(當然,在eclipse中也可以debug除錯python程式)
<