1. 程式人生 > >用keras實現基本的影象分類任務

用keras實現基本的影象分類任務

資料集介紹

fashion mnist資料集是mnist的進階版本,有10種對應的結果

訓練集有60000個,每一個都是28*28的影象,每一個對應一個標籤(0-9)表示

測試集有10000個

程式碼
import tensorflow as tf
import keras
import numpy as np
import matplotlib.pyplot as plt

#匯入fashioin_mnist資料集
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

#分別於0-9對應
class_names = ['上衣','褲子','套衫','裙子','外套','涼鞋','襯衫','運動鞋','包包','踝靴']

#壓縮畫素值到0-1之間
train_images = train_images / 255.0
test_images = test_images / 255.0

#檢視前幾個資料的影象
plt.figure(figsize=(10,10))
for i in range(25):
    plt.subplot(5,5,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(train_images[i], cmap=plt.cm.binary)
    plt.xlabel(class_names[train_labels[i]])
    
model = keras.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),   #輸入影象大小為28*28
    keras.layers.Dense(128, activation=tf.nn.relu),  #用relu函式作為啟用函式
    keras.layers.Dense(10, activation=tf.nn.softmax)   #softmax之後輸出10個值,分別表示對應的概率
])

model.compile(optimizer=tf.train.AdamOptimizer(),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(train_images,train_labels,epochs= 10)  #執行完準確率有91.13%

test_loss, test_acc = model.evaluate(test_images, test_labels)

print('Test accuracy:', test_acc)   #執行完在測試集上的準確率為88.58%
#測試集的準確率小於訓練集,說明過擬合

參考

https://www.tensorflow.org/tutorials/keras/basic_classification?hl=zh-cn