1. 程式人生 > >ImageNet打造自己的影象識別

ImageNet打造自己的影象識別

一、原理
      在自己的資料集上訓練一個新的深度學習模型時,一般採取在預訓練ImageNet上進行微調的方法。什麼是微調?這裡以VGG16為例進行講解。
      VGG16網路結構:http://ethereon.github.io/netscope/#/preset/vgg-16
如下圖:
在這裡插入圖片描述

在這裡插入圖片描述

      VGG16的結構為卷積+全連線層。卷積層分為5個部分共13層,即conv1~conv5。還有三層全連線層,即fc6、fc7、fc8。卷積層加上全連線層合起來一共為16層。如果要將VGG16的結構用於一個新的資料集,首先要去掉fc8這一層。原因是fc8層的輸入是fc7層的特徵,輸出是1000類的概率,這1000類正好對應了ImageNet模型中的1000個類別。在自己的資料中,類別數一般不是1000類,因此fc8層的結構在此時是不適用的。必須將fc8層去掉,重新採用符合資料集類別數的全連線層,作為新的fc8.比如資料集為5類,那麼新的fc8的輸出也應當是5類。

      此外,在訓練的時候,網路的引數的初始值並不是隨機化生成的,而是採用VGG16在ImageNet上已經訓練好的引數作為訓練的初始值。這樣做的原因在於,在ImageNet資料集上訓練過的VGG16的引數已經包含了大量有用的卷積過濾器,與其從零開始初始化VGG16的所有引數,不如使用已經訓練好的引數當作訓練的起點。這樣做不僅可以節約大量訓練時間,而且有助於分類起效能的提高。

      載入VGG16的引數後,就可以開始訓練了。此時需要指定訓練層數的範圍。一般來說,可以選擇以下幾種範圍進行訓練:

  • 只訓練fc8.訓練範圍一定要包含fc8這一層。之前說過,fc8的結構被調整過,因此它的引數不能直接從ImageNet預訓練模型中取得。可以只訓練fc8,保持其他層的引數不動。這就相當於將VGG16當作一個特徵提取器,用fc7層提取的特徵做一個softmax模型分類。這樣做的好處是訓練速度塊,但往往效能不會太好。
  • 訓練所有引數。還可以對網路中的所有引數進行訓練,這種方法的訓練速度可能比較慢,但是能取得較高的效能,可以充分發揮深度模型的威力。
  • 訓練部分引數。通常是固定淺層引數不變,訓練深層引數。如固定conv1、conv2部分的引數不訓練,只訓練conv3、conv4、conv5、fc6、fc7、fc8的引數。

二、使用Tensorflow Slim微調模型
      slim是google公司公佈的一個影象分類工具包,不僅定義了一些方便的介面,還提供了很多ImageNet資料集上常用的網路結構和預訓練模型。包括VGG16\VGG19、Inception v1~v4、ResNet 50、ResNet101、MobileNet在內大多數常用模型的結構以及預訓練模型,更多的模型會被持續新增進來。

  1. 下載Tensorflow Slim的原始碼

      git clone https://github.com/tensorflow/models.git

      找到models/research/slim資料夾。

  1. 資料準備
    將jpg格式樣本集合轉化為tfrecord格式。
    首先做資料準備的工作,一是將資料集切分為訓練集和驗證集,二是轉換為tfrecord格式。建立data_prepare目錄,建立目錄結構如下:
    在這裡插入圖片描述

      在data_prepare目錄下執行指令碼:

python  data_convert.py -t pic/ --train-shards 2 --validation-shards 2 --num-threads 2 --dataset-name satellite

      其中dataset-name為給資料集起的名字。

      執行上述命令後,pic目錄生成如下5個檔案
在這裡插入圖片描述
      tfrecord檔案就是對應的訓練集和驗證集,另外還有label.txt,為類別對映關係。

  1. 定義新的dataset
    在slim/dataset中,定義所有可用的資料庫,前面定義的新的satellite資料集,在這裡也要定義對應的dataset。
    新建 satellite.py檔案,把 flowers.py複製到其中。如下
_FILE_PATTERN='satellite_%s_*.tfrecord'//改成自己的圖片的命名

SPLITS_TO_SIZES={‘train’:4800,'validation':1200}//訓練集和測試集的總數目

_NUM_CLASSES=6  //類別數目

修改圖片的預設格式

keys_to_features = {
      'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
      'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'),
      'image/class/label': tf.FixedLenFeature(
          [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
  }

修改完satellite.py檔案後,還要在dataset_factory.py註冊satellite資料庫

datasets_map = {
    'cifar10': cifar10,
    'flowers': flowers,
    'imagenet': imagenet,
    'mnist': mnist,
	'satellite':satellite,
}
  1. 準備訓練資料夾
    slim下新建satellite目錄,完成以下準備工作:
  • data目錄,把之前生成好的5個檔案複製進來
  • 新建一個空的train_dir目錄,用來儲存訓練過程中的日誌和模型。
  • 新建一個pretrained目錄,在slim的GitHubi頁面找到Inception-V3模型的下載地址http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz,下載並解壓後,得到 inception_v3.ckpt檔案,將該檔案複製到pretrained目錄下。

在這裡插入圖片描述

  1. 開始訓練
python train_image_classifier.py --train_dir=satellite\train_dir --dataset_name=satellite --dataset_split_name=train --dataset_dir=satellite\data --model_name=inception_v3 --checkpoint_path=satellite\pretrained\inception_v3.ckpt --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits --trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits --max_number_of_steps=100000 --batch_size=32 --learning_rate=0.001 --learning_rate_decay_type=fixed --save_interval_secs=120 --save_summaries_secs=20 --log_every_n_steps=10 --optimizer=rmsprop --weight_decay=0.00004   

      trainable_scopes=InceptionV3/Logits, InceptionV3/AuxLogits。trainable_scopes規定了在模型中微調變數的範圍。這裡的設定表示只對 InceptionV3/Logits, InceptionV3/AuxLogits兩個變數進行微調,其他變數都保持不動。
      InceptionV3/Logits, InceptionV3/AuxLogits是inception V3的末端層。只對最後一層分類層進行訓練,比如原來是1000類,現在訓練的只是2類。如果不設定trainable_scopes,就只會對模型中所有的引數進行訓練。

  1. 驗證模型
python eval_image_classifier.py --checkpoint_path=satellite/train_dir --eval_dir=satellite/eval_dir --dataset_name=satellite --dataset_split_name=validation --dataset_dir=satellite/data --model_name=inception_v3

最後顯示:eval/Recall_5[0.979166687]eval/Accuracy[0.561666667]
其中Accuracy為分類準確率,Recall_5表示Top 5的準確率,即在輸出的類別概率中,正確的類別只有落在前5就算對。

  1. Tensorboard視覺化
tensorboard –logdir satellite/train_dir