1. 程式人生 > >使用VGG模型自定義影象分類任務

使用VGG模型自定義影象分類任務

前言

網上關於VGG模型的文章有很多,有介紹演算法本身的,也有程式碼實現,但是很多程式碼只給出了模型的結構實現,並不包含資料準備的部分,這讓人很難愉快的將程式碼遷移自己的任務中。為此,這篇部落格接下來圍繞著如何使用VGG實現自己的影象分類任務,從資料準備到實驗驗證。程式碼基於Python與TensorFlow實現,模型結構採用VGG-16,並且將很少的出現演算法和理論相關的東西。

資料準備

下載資料和轉換程式碼

大多數人自己的訓練資料,一般都是傳統的圖片形式,如.jpg,.png等等,而影象分類任務的話,這些圖片的天然組織形式就是一個類別放在一個資料夾裡,那麼有啥大眾化的資料集是這樣的組織形式呢?TensorFlow的FlowersData,它下載下來是這個樣子:
這裡寫圖片描述


一共有五類,每一類中都有幾百張圖,我們把這些資料組織成TFrecord形式,對應的部落格在這裡,原始碼的github在這裡,FlowersData資料集在這裡
有上面這三個東西之後,就可以生成TFrecord檔案了。

組織圖片資料

首先將FlowersData資料夾下的資料分成兩個部分,訓練資料和測試資料,我把原檔案五個類別中都拿出大概100張圖左右,資料的構成和路徑如下:
這裡寫圖片描述

生成訓練TFrecord

#圖片路徑
cwd = 'F:\\flowersdata\\trainimages\\'
#檔案路徑
filepath = 'F:\\flowersdata\\tfrecord\\
train\\'
classes=['daisy',
         'dandelion',
         'roses',
         'sunflowers',
         'tulips']
#tfrecords格式檔名
ftrecordfilename = ("traindata.tfrecords-%.3d" % recordfilenum)
#tfrecords格式檔名
ftrecordfilename = ("traindata.tfrecords-%.3d" % recordfilenum)

生成效果:
這裡寫圖片描述

生成預測TFrecord

#圖片路徑
cwd = 'F:\\flowersdata\\testimages\\'
#檔案路徑
filepath = 'F:\\flowersdata\\tfrecord\\test\\'
classes=['daisy',
         'dandelion',
         'roses',
         'sunflowers',
         'tulips']
#tfrecords格式檔名
ftrecordfilename = ("testdata.tfrecords-%.3d" % recordfilenum)
#tfrecords格式檔名
ftrecordfilename = ("testdata.tfrecords-%.3d" % recordfilenum)

生成效果:
這裡寫圖片描述

訓練模型

初始權重與原始碼下載

VGG-16的初始權重我上傳到了百度雲,在這裡下載;
VGG-16原始碼我上傳到了github,在這裡下載;

在原始碼中:
train_and_val.py檔案是最終要執行的檔案,它定了訓練和預測的過程;
input_data.py是將上一步中生成的TFRecord檔案組織成batch的過程;
VGG.py定義了VGG-16的網路結構;
tool.py是最底層,定義了一些卷積池化等操作。

訓練模型

train_and_val.py檔案修改:

if __name__=="__main__":
    train()
    #evaluate()

根據自己的路徑修改:

#初始權重路徑
pre_trained_weights = 'vgg16_pretrain/vgg16.npy'
#訓練資料路徑
train_data_dir = 'F:\\flowersdata\\tfrecord\\train\\traindata.tfrecords*'
    test_data_dir = 
#預測資料路徑
'F:\\flowersdata\\tfrecord\\test\\testdata.tfrecords*'
#訓練生成檔案路徑
train_log_dir = 'logs/train/'
#預測生成檔案路徑
val_log_dir = 'logs/val/'

根據自己的視訊記憶體容量修改:

IMG_W = 224
IMG_H = 224
BATCH_SIZE = 8

訓練過程每50個step列印loss;
每200個step計算一個batch中的準確率;
每1000個step儲存一次權重。

預測

train_and_val.py檔案修改:

if __name__=="__main__":
    #train()
    evaluate()
#訓練過程中生成的權重
log_dir = 'logs/train/'
#預測資料集路徑
test_data_dir = 'F:\\flowersdata\\tfrecord\\test\\testdata.tfrecords*'
#用於生成tf檔案的圖片數量
n_test = 502

列印測試樣本總數;
列印正確預測的樣本總數;
列印top_1。