1. 程式人生 > 其它 >TF10——卷積神經網路搭建示例

TF10——卷積神經網路搭建示例

TF10——卷積神經網路搭建示例

用卷積神經網路訓練cifar10資料集,搭建一個一層卷積,兩層全連線的網路

  • 使用6個5*5卷積核,過2*2池化核,池化步長是2
  • 過128個神經元的全連線層
  • 由於cifar10是10分類,所以最後還要過一層十個神經元的全連線層

根據搭建卷積網路的八股口訣CBAPD ,搭建如下神經網路:

C:6個5*5的卷積核,步長是1,使用全零填充

B:使用批標準化

A:使用relu啟用函式

P:使用最大池化,池化核是2*2,池化步長是2,使用全零填充

D:把20%的神經元休眠

Flatten把卷積送過來的資料拉直

送入128個神經元的全連線,過relu啟用,過0.2的Dropout

過10個神經元的全連線,過softmax函式使輸出符合概率分佈

程式碼實現

由於網路相對複雜了,所以我用class類搭建網路結構

__init__函式中,準備出搭建神經網路要用到的每一層結構,卷積網路就是CBAPD,

  • C裡面有6個卷積核,每個卷積核都是5*5的尺寸,使用全零填充

  • B使用批標準化

  • A使用relu啟用函式

  • P使用最大池化,池化核是2*2,池化步長是2,使用全零填充

  • D把20%的神經元休眠

  • Flatten把卷積送過來的資料拉直

  • 送入128個神經元的全連線,過relu啟用

  • 把20%的神經元休眠

  • 過10個神經元的全連線,過softmax函式使輸出符合概率分佈

在call函式中,呼叫__init__函式裡搭建好的每層網路結構,從輸入到輸出過一次前向傳播,返回推理結果y

用六步法寫出的程式碼是這樣的:

原始碼:p27_cifar10_baseline.py

## import部分
import tensorflow as tf
import os
import numpy as np
from matplotlib import pyplot as plt
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Dropout, Flatten, Dense
from tensorflow.keras import Model
np.set_printoptions(threshold=np.inf)#解決輸出陣列時的省略情況

## train test 給訓練集與測試集輸入特徵、訓練集與測試集標籤
cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

## class model
class Baseline(Model):
    def __init__(self):
        super(Baseline, self).__init__()
        self.c1 = Conv2D(filters=6, kernel_size=(5, 5), padding='same')  # 卷積層
        self.b1 = BatchNormalization()  # BN層
        self.a1 = Activation('relu')  # 啟用層
        self.p1 = MaxPool2D(pool_size=(2, 2), strides=2, padding='same')  # 池化層
        self.d1 = Dropout(0.2)  # dropout層

        self.flatten = Flatten()
        self.f1 = Dense(128, activation='relu')
        self.d2 = Dropout(0.2)
        self.f2 = Dense(10, activation='softmax')

    def call(self, x):
        x = self.c1(x)
        x = self.b1(x)
        x = self.a1(x)
        x = self.p1(x)
        x = self.d1(x)

        x = self.flatten(x)
        x = self.f1(x)
        x = self.d2(x)
        y = self.f2(x)
        return y

model = Baseline()

## model.compile 配置訓練方法,告訴訓練時選擇哪種優化器,選擇哪個損失函式,選擇哪種評測指標
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

checkpoint_save_path = "./checkpoint/Baseline.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)

## model.fit執行訓練過程,告訴訓練集和測試集的輸入特徵和標籤,告知每個batch是多少,告知要迭代多少次資料集,告知測試集,告知多少次資料集迭代用測試集驗證準確率,使用回撥函式實現斷點續訓
history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])
## model.summary打印出網路的結構和引數
model.summary()
## 實現引數提取
# print(model.trainable_variables)
file = open('./weights.txt', 'w')
for v in model.trainable_variables:
    file.write(str(v.name) + '\n')
    file.write(str(v.shape) + '\n')
    file.write(str(v.numpy()) + '\n')
file.close()

## acc/loss視覺化
###############################################    show   ###############################################

# 顯示訓練集和驗證集的acc和loss曲線
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

上述程式碼作為後一講後續內容的基準程式碼,在介紹經典卷積神經網路結構時,我們只替換class model模組的內容其餘程式碼不變

程式執行結果:

隨著迭代次數的增加,準確率在不斷的提高!

打印出了acc/loss視覺化效果

在weight.txt檔案裡記錄了所有的可訓練引數

我們可以在任何平臺復現出神經網路的前向傳播實現應用