1. 程式人生 > >CIFAR-10 圖像識別

CIFAR-10 圖像識別

producer 位置 改變 依次 裏的 success check alt shel

零、學習目標
  1. tensorflow 數據讀取原理
  2. 深度學習數據增強原理

一、CIFAR-10數據集簡介

是用於普通物體識別的小型數據集,一共包含 10個類別RGB彩×××片(包含:(飛機、汽車、鳥類、貓、鹿、狗、蛙、馬、船、卡車)。圖片大小均為 3232像素**,數據集中一共有 50000 張訓練圖片和 1000*** 張測試圖片。部分代碼來自於tensorflow官方,以下表格列出了所需的官方代碼。

文件 用途
cifar10.py 建立CIFAR-1O預測模型
cifar10_input.py 在tensorflow中讀入CIFAR-10訓練圖片
cifar10_input_test.py cifar10_input 的測試用例文件
cifar10_train.py 使用單個GPU或CPU訓練模型
cifar10_train_multi_gpu.py 使用多個gpu訓練模型
cifar10_eval.py 在測試集上測試模型的性能

二、下載CIFAR-10數據

在工程根目錄創建 cifar10_download.py ,輸入如下代碼創建下載數據的程序:

# 引入當前目錄中已經編寫好的cifar10模塊
import cifar10
# 引入tensorflow
import tensorflow as tf

# 定義全局變量存儲器,可用於命令行參數的處理
# tf.app.flags.FLAGS 是tensorflow 內部的一個全局變量存儲器
FLAGS = tf.app.flags.FLAGS
# 在cifar10 模塊中預先定義了cifar-10的數據存儲路徑,修改數據存儲路徑
FLAGS.data_dir = ‘cifar10_data/‘
# 如果數據不存在,則下載
cifar10.maybe_download_and_extract()

執行完這段代碼後,CIFAR-10數據集會下載到目錄 cifar10_data 目錄下。默認的存儲路徑書 tmp/cifar10_data,定義在代碼文件cifar10.py中,位置大約在53行附近。
修改完數據存儲路徑後,通過 cifar10.maybe_download_and_extract() 來下載數據,下載期間如果數據存在於數據文件夾中則跳過下載數據,反之下載數據。下載成功後會提示 Successfully downloaded cifar-10-binary.tar.gz 170052171 bytes.
下載完成後,cifar10_data/cifar-10-batches-bin 中將出現8個文件,名稱和用途如下表:

文件名 用途
batches.meta.txt 存儲每個類別的英文名
data_batch_1.bin、......、data_batch_5.bin CIFAR-10的五個訓練集,每個訓練集用二進制格式存儲了10000張32*32的彩×××像和圖相對應的標簽,沒個樣本由3073個字節組成,第一個字節未標簽,剩下的字節未圖像數據
test_batch.bin 存儲1000張用於測試的圖像和對應的標簽
readme.html 數據集介紹文件

三、TensorFlow 讀取數據的機制

  1. 普通方式
    將硬盤上的數據讀入內存中,然後提供給CPU或者GPU處理
  2. 內存隊列方式
    普通方式讀取數據會出現GPU或CPU在一段時間內存在空閑,導致運算效率降低。利用內存隊列,將數據讀取和計算放在兩個線程中,讀取線程只需向內存隊列中讀入文件,而計算線程只用從內存隊列中讀取計算需要的數據,這樣就解決了GPU或者CPU的空閑問題。
  3. 文件名隊列+內存隊列
    TensorFlow采用 文件名隊列+內存隊列,這種方式可以很好的管理epoch(註1)和避免計算單元的空閑問題。舉個例子,假設有三個數據文件要執行一次epoch,那麽就在文件名隊列中放入這三個數據文件各一次,並且在最後放入的數據文件後面標註隊列結束。內存隊列依次從文件名隊列的頂部讀取數據文件,讀到結束標記後就會自動拋出異常,捕獲這個異常後程序就可以結束。如果是執行N次epoch,那麽就把每個數據文件放入文件名隊列N次。

    註1:
    對於數據集來說,運行一次epoch就是將數據集裏的所有數據完整的計算一遍,以此類推運行N次epoch就是將數據集裏的所有數據完整的計算N遍

四、創建文件名隊列和內存隊列

  1. 創建文件名隊列
    利用tensorflow的 tf.train.string_input_producer()(註2) 函數。給函數傳入一個文件名列表,系統將會轉換未文件名隊列。tf.train.string_input_producer() 函數有兩個重要的參數,分別是 num_epochsshuffle ,num_epochs表示epochs數,shuffle表示是否打亂文件名隊列內文件的順序,如果是True表示不按照文件名列表添加的順序進入文件名隊列,如果是Flase表示按照文件名列表添加的順序進入文件名隊列。
  2. 創建內存隊列
    在tensorflow中不手動創建內存隊列,只需使用 reader對象從文件名隊列中讀取數據就可以了。

    註2:
    使用tf.train.string_input_producer() 創建完文件名隊列後,文件名並沒有被加入到隊列中,如果此時開始計算,會導致整個系統處於阻塞狀態。
    在創建完文件名隊列後,應調用 tf.train.start_queue_runners 方法才會啟動文件名隊列的填充,整個程序才能正常運行起來。

  3. 代碼
import tensorflow as tf

# 新建session
with tf.Session() as sess:
    # 要讀取的三張圖片
    filename = [‘img/1.jpg‘, ‘img/2.jpg‘, ‘img/3.jpg‘]
    # 創建文件名隊列
    filename_queue = tf.train.string_input_producer(filename, num_epochs=5, shuffle=False)
    reader = tf.WholeFileReader()
    key, value = reader.read(filename_queue)
    # 初始化變量(epoch)
    tf.local_variables_initializer().run()
    threads = tf.train.start_queue_runners(sess=sess)
    i = 0
    while True:
        i += 1
        # 獲取圖片保存數據
        image_data = sess.run(value)
        with open(‘read/test_%d.jpg‘ % i, ‘wb‘) as f:
            f.write(image_data)

五、數據增強

對於圖像數據來說,數據增強方法就是利用平移、縮放、顏色等變換增大訓練集樣本個數,從而達到更好的效果(註3),使用數據增強可以大大提高模型的泛化能力,並且能夠預防過擬合。
常用的圖像數據增強方法如下表

方法 說明
平移 將圖像在一定尺度範圍內平移
旋轉 將圖像在一定角度範圍內旋轉
翻轉 水平翻轉或者上下翻轉圖片
裁剪 在原圖上裁剪出一塊
縮放 將圖像在一定尺度內放大或縮小
顏色變換 對圖像的RGB顏色空間進行一些變換
噪聲擾動 給圖像加入一些人工生成的噪聲

註3:
使用數據增強的方法前提是,這些數據增強方法不會改變圖像的原有標簽。比如數字6的圖片,經過上下翻轉之後就變成了數字9的圖片。

六、CIFAR-10識別模型

建立模型的代碼在cifar10.py文件額inference函數中,代碼在這裏不進行詳解,讀者可以去閱讀代碼中的註釋。
這裏我們通過以下命令訓練模型:

python cifar10_train.py --train_dir cifar10_train/ --data_dir cifar10_data/

這段命令中 --data_dir cifar10_data/ 表示數據保存的位置, --train_dir cifar10_train/ 表示保存模型參數和訓練時日誌信息的位置

七、查看訓練進度

在訓練的時候我們往往需要知道損失的變化和每層的訓練情況,這個時候我們就會用到tensorflow提供的 TensorBoard。打開一個新的命令行,輸入如下命令:

tensorboard --logdir cifar10_train/

其中 --logdir cifar10_train/ 表示模型訓練日誌保存的位置,運行該命令後將會在命令行看到類似如下的內容
技術分享圖片
在瀏覽器上輸入顯示的地址,即可訪問TensorBoard。簡單解釋一下常用的幾個標簽:

標簽 說明
total_loss_1 loss 的變化曲線,變化曲線會根據時間實時變化
learning_rate 學習率變化曲線
global_step 美妙訓練步數的情況,如果訓練速度變化較大,或者越來越慢,就說明程序有可能存在錯誤

八、檢測模型的準確性

在命令行窗口輸入如下命令:

python cifar10_eval.py --data_dir cifar10_data/ --eval_dir cifar10_eval/ --checkpoint_dir cifar10_train/

--data_dir cifar10_data/ 表 示 CIFAR-10 數據集的存儲位置 。
--heckpoint_dir cifar1O_train/ 則表示程序模型保存在 cifar10_train/文件夾下。
--eval_dir cifar10_eval/ 指定了一個保存測試信息的文件夾
輸入以下命令,在TensorBoard上查看準確率歲訓練步數的變化情況:

tensorboard --logdir cifar10_eval/ --port 6007

在瀏覽器中輸入:http://127.0.0.1:6007,展開 Precision @ 1 選項卡,就可以看到準確率隨訓練步數變化的情況。

九、代碼下載

Git地址:https://gitee.com/bugback/ai_learning.git
百度網盤:https://pan.baidu.com/s/17HdfI2R9gsOMKi4pgundSA

CIFAR-10 圖像識別