[keras實戰] 小型CNN實現Cifar-10資料集84%準確率
阿新 • • 發佈:2019-02-14
實驗環境
程式碼基於 Python 2.7 與 Keras 1(部分介面在 Keras 2 中已經被修改,如果你使用的是 Keras 2,請查閱官方文件並修改對應的介面)
個人使用的是蟲資料提供的免費GPU主機,GTX1080顯示卡,因為是免費賬號,所以視訊記憶體最高只有1G。為了防止超視訊記憶體程序被kill,我在開頭設定了佔用GPU的最大視訊記憶體大小,如果在這方面沒有限制,可以註釋掉這些程式碼。整個網路引數約160萬,個人測試的結果一個epoch大概需要20s,最終accuracy在84%左右。
程式碼
# import the modules we need
from keras.models import Sequential
from keras.layers.convolutional import Conv2D
from keras.layers.convolutional import MaxPooling2D
from keras.layers.convolutional import AveragePooling2D
from keras.layers.core import Activation
from keras.layers.core import Flatten
from keras.layers.core import Dense
from keras.layers.core import Dropout
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.normalization import BatchNormalization
from keras.datasets import cifar10
from keras.utils.np_utils import to_categorical
from keras import metrics
from keras.optimizers import SGD,RMSprop,Adam
from keras.callbacks import EarlyStopping
from keras.backend.tensorflow_backend import set_session
import tensorflow as tf
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
# Cap the share of GPU memory this process may claim (9% here) before any
# Keras work starts, then register the resulting session with Keras.
# This whole section can be commented out on a machine without a memory limit.
config = tf.ConfigProto(
    gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.09)
)
set_session(tf.Session(config=config))
#define the Sequential model
class CNNNet:
    """Builder for the small convolutional classifier used by this script."""

    @staticmethod
    def createNet(input_shapes,nb_class):
        """Assemble and return the (uncompiled) Sequential model.

        Arguments:
            input_shapes: shape of a single input sample; forwarded as
                ``input_shape`` to the first layer.
            nb_class: number of units of the final softmax layer.

        Returns:
            A ``Sequential`` model made of three convolution stages
            followed by a dense classifier head.
        """
        net = Sequential()

        # --- stage 1: two 3x3 convolutions with 64 filters -----------------
        net.add(BatchNormalization(input_shape=input_shapes))
        net.add(Conv2D(64, 3, 3, border_mode="same"))
        net.add(Activation("relu"))
        net.add(BatchNormalization())
        net.add(Conv2D(64, 3, 3, border_mode="same"))
        net.add(Activation("relu"))
        net.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))

        # --- stage 2: two 3x3 convolutions with 128 filters ----------------
        net.add(BatchNormalization())
        net.add(Conv2D(128, 3, 3, border_mode="same"))
        net.add(Activation("relu"))
        net.add(BatchNormalization())
        net.add(Dropout(0.5))
        net.add(Conv2D(128, 3, 3, border_mode="same"))
        net.add(Activation("relu"))
        net.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))

        # --- stage 3: two more 128-filter convolutions, dropout before each
        net.add(BatchNormalization())
        net.add(Dropout(0.5))
        net.add(Conv2D(128, 3, 3, border_mode="same"))
        net.add(Activation("relu"))
        net.add(Dropout(0.5))
        net.add(Conv2D(128, 3, 3, border_mode="same"))
        net.add(Activation("relu"))
        net.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
        net.add(BatchNormalization())

        # --- classifier head ------------------------------------------------
        net.add(Flatten())
        net.add(Dense(512))
        net.add(Activation("relu"))
        net.add(Dropout(0.5))
        net.add(Dense(nb_class))
        net.add(Activation("softmax"))

        return net
# ---- training hyper-parameters ----
NB_EPOCH = 40            # number of epochs requested from fit()
BATCH_SIZE = 128         # samples per gradient update
VERBOSE = 1              # Keras verbosity level for fit()/evaluate()
VALIDATION_SPLIT = 0.2   # fraction of the training data held out for validation

# ---- dataset geometry ----
IMG_ROWS = IMG_COLS = 32                  # image height and width in pixels
NB_CLASSES = 10                           # number of target classes
INPUT_SHAPE = (IMG_ROWS, IMG_COLS, 3)     # rows, cols, 3 colour channels
# Load the CIFAR-10 dataset and convert the images to float32.
(X_train, Y_train), (X_test, Y_test) = cifar10.load_data()
X_train = X_train.astype("float32")
X_test = X_test.astype("float32")

# Force the (samples, rows, cols, channels) layout the model is built for.
X_train = X_train.reshape(X_train.shape[0], IMG_ROWS, IMG_COLS, 3)
X_test = X_test.reshape(X_test.shape[0], IMG_ROWS, IMG_COLS, 3)

# %-formatting instead of print(a, b): under Python 2 the latter prints a
# tuple repr such as "(50000, 'train samples')"; this form reads the same on
# both Python 2 and Python 3.
print("%d train samples" % X_train.shape[0])
print("%d test samples" % X_test.shape[0])

# Convert the integer class labels to one-hot (binary class matrix) form,
# as required by the categorical cross-entropy loss.
Y_train = to_categorical(Y_train, NB_CLASSES)
Y_test = to_categorical(Y_test, NB_CLASSES)
# Build the network from the shared INPUT_SHAPE constant (rather than a
# second hard-coded copy of the shape) so the model and the data reshaping
# cannot drift apart.
model = CNNNet.createNet(input_shapes=INPUT_SHAPE, nb_class=NB_CLASSES)
model.summary()

# Adam optimizer with categorical cross-entropy; 'acc' makes the training
# history expose the "acc" / "val_acc" series.
model.compile(loss="categorical_crossentropy", optimizer='adam', metrics=['acc'])

# Stop training once the validation loss has not improved for 2 epochs.
early_stopping = EarlyStopping(monitor='val_loss', patience=2)

# Train; VALIDATION_SPLIT of the training data is held out for validation.
history = model.fit(
    X_train, Y_train,
    batch_size=BATCH_SIZE,
    nb_epoch=NB_EPOCH,
    verbose=VERBOSE,
    validation_split=VALIDATION_SPLIT,
    callbacks=[early_stopping],
)
# Evaluate on the held-out test set. With the single 'acc' metric configured
# at compile time, score[0] is the loss and score[1] the accuracy.
score = model.evaluate(X_test, Y_test, verbose=VERBOSE)

# Print the two numbers framed by a double separator line on each side.
separator = "===================================="
print("")
print(separator)
print(separator)
print(score[0])
print(score[1])
print(separator)
print(separator)

# Persist the trained model; the test accuracy is embedded in the file name.
model_path = "my_model" + str(score[1]) + ".h5"
model.save(model_path)
# List the metric series recorded during training.
print(history.history.keys())

# Draw accuracy and loss in two separate panels. Plotting both on the same
# axes (as a single pyplot sequence would) overlays the curves and lets the
# second title/labels/legend overwrite the first.
plt.figure(figsize=(12, 5))

# Left panel: accuracy per epoch. The "val_*" series come from the
# validation split of the training data, so they are labelled "validation".
plt.subplot(1, 2, 1)
plt.plot(history.history["acc"])
plt.plot(history.history["val_acc"])
plt.title("Model accuracy")
plt.ylabel("accuracy")
plt.xlabel("epoch")
plt.legend(["train", "validation"], loc="upper left")

# Right panel: loss per epoch.
plt.subplot(1, 2, 2)
plt.plot(history.history["loss"])
plt.plot(history.history["val_loss"])
plt.title("Model loss")
plt.ylabel("loss")
plt.xlabel("epoch")
plt.legend(["train", "validation"], loc="upper left")

# Keep the two panels from overlapping, then write the figure to disk.
# NOTE(review): the ':' in the file name is not a valid path character on
# Windows; the name is kept unchanged for compatibility with existing runs.
plt.tight_layout()
plt.savefig("Performance:"+str(score[1])+".jpg")