1. 程式人生 > >MNIST資料集手寫體識別(SEQ2SEQ實現)

MNIST資料集手寫體識別(SEQ2SEQ實現)

github部落格傳送門
csdn部落格傳送門

本章所需知識:

  1. 沒有基礎的請觀看深度學習系列視訊
  2. tensorflow
  3. Python基礎

    資料下載連結:

    深度學習基礎網路模型(mnist手寫體識別資料集)

MNIST資料集手寫體識別(CNN實現)

import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data  # 匯入下載資料集手寫體
mnist = input_data.read_data_sets('../MNIST_data/', one_hot=True)


class SEQ2SEQNet:  # 建立一個SEQ2SEQNet類
    def __init__(self):
        self.x = tf.placeholder(dtype=tf.float32, shape=[None, 28, 28], name='input_x')  # 建立資料佔位符
        self.y = tf.placeholder(dtype=tf.float32, shape=[None, 10], name='input_y')  # 建立標籤佔位符

        self.fc1_w = tf.Variable(tf.truncated_normal(shape=[128, 10], dtype=tf.float32, stddev=tf.sqrt(1 / 10)))  # 定義 輸出層/全連結層 w
        self.fc1_b = tf.Variable(tf.zeros(shape=[10], dtype=tf.float32))  # 定義 輸出層/全連結層 偏值b

    # 前向計算
    def forward(self):
        # 編碼
        with tf.variable_scope('encode'):  # 建立一個變數空間 encode
            self.encoder_cell = tf.nn.rnn_cell.BasicLSTMCell(128)  # 建立128個LSTM的RNN結構(細胞結構)
            self.encoder_states = self.encoder_cell.zero_state(100, dtype=tf.float32)  # 初始化細胞的狀態為 0, 傳入初始化批次 和資料型別
            self.encoder_output, self.encoder_state = tf.nn.dynamic_rnn(self.encoder_cell, self.x, initial_state=self.encoder_states, time_major=False)  # 將細胞cell 和資料 self.x 初始化狀態傳入RNN細胞結構 獲得兩個返回值 output 和 狀態state
            self.flat = tf.transpose(self.encoder_output, [1, 0, 2])[-1]  # 取rnn_output的輸出狀態的 每個輸出的最後一行 (相當於 self.rnn_ouput[:, -1, :])
            self.flat1 = tf.expand_dims(self.flat, axis=1)  # 增加了一個維度
            self.flat2 = tf.tile(self.flat1, [1, 4, 1])  # 將增加的那個維度進行 複製為 4行 不復制也行 reshape為 NSV結構[批次, 步長, 資料]也行.

        # 解碼
        with tf.variable_scope('decode'):  # 建立一個變數空間 decode
            self.decoder_cell = tf.nn.rnn_cell.BasicLSTMCell(128)  # 建立128個LSTM的RNN結構
            self.decoder_states = self.decoder_cell.zero_state(100, dtype=tf.float32)  # 初始化細胞的狀態為 0, 傳入初始化批次 和資料型別
            self.decoder_output, self.decoder_state = tf.nn.dynamic_rnn(self.decoder_cell, self.flat2, initial_state=self.decoder_states, time_major=False)  # 將細胞cell 和資料 self.flat2 初始化狀態傳入RNN細胞結構 獲得兩個返回值 output 和 狀態state
            self.flat3 = tf.transpose(self.decoder_output, [1, 0, 2])[-1]  # 同上
            self.fc_y = tf.nn.relu(tf.matmul(self.flat3, self.fc1_w)+self.fc1_b)  # 全連結層
            self.output = tf.nn.softmax(self.fc_y)  # softmax分類
    
    # 後向計算
    def backword(self):
        self.cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=self.y, logits=self.fc_y))  # 定義損失, softmax交叉熵
        self.opt = tf.train.AdamOptimizer().minimize(self.cost)  # 使用AdamOptimizer優化損失
    
    # 計算測試集識別精度
    def acc(self):
        # 將預測值 output 和 標籤值 self.y 進行比較
        self.acc1 = tf.equal(tf.argmax(self.output, 1), tf.argmax(self.y, 1))
        #  最後對比較出來的bool值 轉換為float32型別後 求均值就可以看到滿值為 1的精度顯示
        self.accaracy = tf.reduce_mean(tf.cast(self.acc1, dtype=tf.float32))


if __name__ == '__main__':
    net = SEQ2SEQNet()  # 啟動tensorflow繪圖的SEQ2SEQNet
    net.forward()  # 啟動前向計算
    net.backward()  # 啟動後向計算
    net.acc()  # 啟動精度計算
    init = tf.global_variables_initializer()  # 定義初始化tensorflow所有變數操作
    with tf.Session() as sess:  # 建立一個Session會話
        sess.run(init)  # 執行init變數內的初始化所有變數的操作
        for i in range(10000):  # 訓練10000次
            ax, ay = mnist.train.next_batch(100)  # 從mnist資料集中取資料出來 ax接收圖片 ay接收標籤
            ax_batch = ax.reshape(-1, 28, 28)  # 將取出的 圖片資料 reshape成 NSV 結構
            loss, output, accaracy, _ = sess.run(fetches=[net.cost, net.output, net.accaracy, net.opt], feed_dict={net.x: ax_batch, net.y: ay})  # 將資料喂進編碼網路
            # print(loss)  # 列印損失
            # print(accaracy)  # 列印訓練精度
            if i % 100 == 0:  # 每訓練100次
                test_ax, test_ay = mnist.test.next_batch(100)  # 則使用測試集對當前網路進行測試
                test_ax_batch = test_ax.reshape(-1, 28, 28)  # 將取出的 圖片資料 reshape成 NSV 結構
                test_output = sess.run(fetches=net.output, feed_dict={net.x: test_ax_batch})  # 將資料喂進編碼網路  接收一個output值
                test_acc = sess.run(tf.equal(tf.argmax(test_output, 1), tf.argmax(test_ay, 1)))  # 對output值和標籤y值進行求比較運算
                test_accaracy = sess.run(tf.reduce_mean(tf.cast(test_acc, dtype=tf.float32)))  # 求出精度的準確率進行列印
                print(test_accaracy)  # 列印當前測試集的精度

最後附上訓練截圖:

SEQ2SEQ