Implementing Image Classification with TensorFlow-Slim

Reference: https://github.com/tensorflow/models/tree/master/slim

Preparation

Install TensorFlow

See https://www.tensorflow.org/install/

For example, to install TensorFlow with GPU support for Python 2.7 on Ubuntu:

wget https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.2.0-cp27-none-linux_x86_64.whl
pip install tensorflow_gpu-1.2.0-cp27-none-linux_x86_64.whl

Download the TF-Slim image models library

cd $WORKSPACE
git clone https://github.com/tensorflow/models/

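Prepare the data

Many public datasets are available; here we use the Flowers dataset provided on the official site as an example.

The official repository includes code that downloads and converts this data. To understand that code and to be able to use our own data, we adapt it here.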

mkdir -p $WORKSPACE/data && cd $WORKSPACE/data
wget http://download.tensorflow.org/example_images/flower_photos.tgz
tar zxf flower_photos.tgz

The dataset directory structure is as follows:

flower_photos
├── daisy
│  ├── 100080576_f52e8ee070_n.jpg
│  └── ...
├── dandelion
├── LICENSE.txt
├── roses
├── sunflowers
└── tulips

In practice, your own dataset will not necessarily store images in one folder per class, so we generate list.txt to record each image's path and label.

Python code:

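import os

# Map each class folder name to an integer label (0 .. num_classes - 1).
class_names_to_ids = {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
data_dir = 'flower_photos/'
output_path = 'list.txt'

# Each output line is "<class_name>/<image_name> <label>", relative to data_dir.
with open(output_path, 'w') as fd:
  # Sorting makes list.txt (and therefore the seeded shuffle below) reproducible,
  # since os.listdir() order depends on the file system.
  for class_name in sorted(class_names_to_ids):
    for image_name in sorted(os.listdir(os.path.join(data_dir, class_name))):
      fd.write('{}/{} {}\n'.format(class_name, image_name, class_names_to_ids[class_name]))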
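If your labels live in an annotation file rather than in folder names, the same list.txt format can be produced directly from that file. A minimal sketch, assuming a hypothetical headerless CSV whose rows look like `some/dir/image_0001.jpg,roses` (both the file layout and the helper `write_list_from_csv` are illustrative, not part of TF-Slim):

```python
import csv


def write_list_from_csv(csv_path, output_path, class_names_to_ids):
    """Write list.txt lines ('<image path> <label>') from a CSV of (path, class name).

    Assumes a hypothetical headerless CSV; image paths must be relative to the
    data_dir that is later passed to convert_dataset().
    """
    with open(csv_path) as fin, open(output_path, 'w') as fout:
        for row in csv.reader(fin):
            if not row:
                continue  # skip blank rows
            image_path, class_name = row[0].strip(), row[1].strip()
            fout.write('{} {}\n'.format(image_path, class_names_to_ids[class_name]))
```

For example, `write_list_from_csv('annotations.csv', 'list.txt', class_names_to_ids)` with the dictionary defined above. Because the later scripts split each line on whitespace, image paths must not contain spaces.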

To make label names easy to look up later, you can also create labels.txt, where line i holds the name of class i:

daisy
dandelion
roses
sunflowers
tulips
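labels.txt is only correct if its line order matches the ids in class_names_to_ids. A small sketch that derives the file from that same dictionary, so the two cannot drift apart (the helper name `write_labels_file` is illustrative):

```python
def write_labels_file(class_names_to_ids, output_path='labels.txt'):
    """Write one class name per line so that line i holds the name of class i."""
    ids = sorted(class_names_to_ids.values())
    # Line numbers double as label ids, so the ids must be exactly 0..N-1.
    if ids != list(range(len(ids))):
        raise ValueError('class ids must be 0..N-1 with no gaps or duplicates')
    with open(output_path, 'w') as fd:
        for name, _ in sorted(class_names_to_ids.items(), key=lambda item: item[1]):
            fd.write(name + '\n')
```

Called with the Flowers dictionary, it writes exactly the five lines listed above.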

Randomly split the list into a training set and a validation set:

Python code:

import random

_NUM_VALIDATION = 350  # number of images held out for validation
_RANDOM_SEED = 0       # fixed seed so that the split can be reproduced
list_path = 'list.txt'
train_list_path = 'list_train.txt'
val_list_path = 'list_val.txt'

with open(list_path) as fd:
  lines = fd.readlines()
random.seed(_RANDOM_SEED)
random.shuffle(lines)

# The first _NUM_VALIDATION shuffled lines form the validation set; the rest are for training.
with open(train_list_path, 'w') as fd:
  fd.writelines(lines[_NUM_VALIDATION:])
with open(val_list_path, 'w') as fd:
  fd.writelines(lines[:_NUM_VALIDATION])
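The number of lines in list_train.txt is the value passed later as --num_samples for training (3320 here: the 3670 Flowers images minus the 350 held out), and the line count of list_val.txt is the --num_samples for evaluation. A small helper, with an illustrative name, that reports these counts together with the per-label distribution, which makes a badly unbalanced random split easy to spot:

```python
from collections import Counter


def count_samples(list_path):
    """Return (total, Counter mapping label -> count) for a '<path> <label>' list file."""
    counts = Counter()
    with open(list_path) as fd:
        for line in fd:
            parts = line.split()
            if parts:  # ignore blank lines
                counts[int(parts[-1])] += 1
    return sum(counts.values()), counts
```

For example, `total, per_label = count_samples('list_train.txt')`, then pass `total` as --num_samples.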

Generate the TFRecord data:

Python code:

import sys
sys.path.insert(0, '../models/slim/')  # make the slim "datasets" package importable
from datasets import dataset_utils
import math
import os
import tensorflow as tf

def convert_dataset(list_path, data_dir, output_dir, _NUM_SHARDS=5):
  """Write the images listed in list_path into _NUM_SHARDS TFRecord files in output_dir."""
  with open(list_path) as fd:
    lines = [line.split() for line in fd if line.strip()]  # [relative_path, label]
  num_per_shard = int(math.ceil(len(lines) / float(_NUM_SHARDS)))
  with tf.Graph().as_default():
    # Small decoding graph, used only to read each image's height and width.
    decode_jpeg_data = tf.placeholder(dtype=tf.string)
    decode_jpeg = tf.image.decode_jpeg(decode_jpeg_data, channels=3)
    with tf.Session('') as sess:
      for shard_id in range(_NUM_SHARDS):
        output_path = os.path.join(output_dir, 'data_{:05}-of-{:05}.tfrecord'.format(shard_id, _NUM_SHARDS))
        tfrecord_writer = tf.python_io.TFRecordWriter(output_path)
        start_ndx = shard_id * num_per_shard
        end_ndx = min((shard_id + 1) * num_per_shard, len(lines))
        for i in range(start_ndx, end_ndx):
          sys.stdout.write('\r>> Converting image {}/{} shard {}'.format(
            i + 1, len(lines), shard_id))
          sys.stdout.flush()
          # Store the original JPEG bytes; decoding is only needed for the image size.
          with tf.gfile.FastGFile(os.path.join(data_dir, lines[i][0]), 'rb') as f:
            image_data = f.read()
          image = sess.run(decode_jpeg, feed_dict={decode_jpeg_data: image_data})
          height, width = image.shape[0], image.shape[1]
          example = dataset_utils.image_to_tfexample(
            image_data, b'jpg', height, width, int(lines[i][1]))
          tfrecord_writer.write(example.SerializeToString())
        tfrecord_writer.close()
  sys.stdout.write('\n')
  sys.stdout.flush()

os.system('mkdir -p train')
convert_dataset('list_train.txt', 'flower_photos', 'train/')
os.system('mkdir -p val')
convert_dataset('list_val.txt', 'flower_photos', 'val/')
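convert_dataset() cuts the list into _NUM_SHARDS contiguous slices of ceil(N / _NUM_SHARDS) lines, and the last slice takes whatever remains. The same index arithmetic in plain Python (the helper name is illustrative) shows how many records each shard file should contain:

```python
import math


def shard_ranges(num_lines, num_shards=5):
    """Return the [start, end) line range that convert_dataset() writes to each shard."""
    num_per_shard = int(math.ceil(num_lines / float(num_shards)))
    return [(shard_id * num_per_shard,
             min((shard_id + 1) * num_per_shard, num_lines))
            for shard_id in range(num_shards)]
```

With 3320 training lines every shard gets 664 records, and with 350 validation lines every shard gets 70. When N is small compared with the shard count, trailing ranges can be empty (start >= end); convert_dataset() still creates those files, they just contain no records.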

The resulting directory structure is:

data
├── flower_photos
├── labels.txt
├── list_train.txt
├── list.txt
├── list_val.txt
├── train
│  ├── data_00000-of-00005.tfrecord
│  ├── ...
│  └── data_00004-of-00005.tfrecord
└── val
  ├── data_00000-of-00005.tfrecord
  ├── ...
  └── data_00004-of-00005.tfrecord

(Optional) Download a pretrained model

Many pretrained models are officially available; here we use Inception-ResNet-v2 as an example.

mkdir -p $WORKSPACE/checkpoints && cd $WORKSPACE/checkpoints
wget http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz
tar zxf inception_resnet_v2_2016_08_30.tar.gz

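Training

Reading the data

The official repository provides models/slim/datasets/flowers.py for reading the Flowers dataset. As before, we adapt it so that it can read the generic dataset format defined above.

Save the following code as models/slim/datasets/dataset_classification.py: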

import os
import tensorflow as tf
slim = tf.contrib.slim

def get_dataset(dataset_dir, num_samples, num_classes, labels_to_names_path=None, file_pattern='*.tfrecord'):
  """Describe the TFRecord files in dataset_dir as a slim Dataset."""
  file_pattern = os.path.join(dataset_dir, file_pattern)
  # Features written by dataset_utils.image_to_tfexample().
  keys_to_features = {
    'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
    'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
    'image/class/label': tf.FixedLenFeature(
      [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
  }
  items_to_handlers = {
    'image': slim.tfexample_decoder.Image(),
    'label': slim.tfexample_decoder.Tensor('image/class/label'),
  }
  decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)
  items_to_descriptions = {
    'image': 'A color image of varying size.',
    'label': 'A single integer between 0 and ' + str(num_classes - 1),
  }
  # Optional mapping from label id to class name (line i of labels.txt -> id i).
  labels_to_names = None
  if labels_to_names_path is not None:
    with open(labels_to_names_path) as fd:
      labels_to_names = {i: line.strip() for i, line in enumerate(fd)}
  return slim.dataset.Dataset(
      data_sources=file_pattern,
      reader=tf.TFRecordReader,
      decoder=decoder,
      num_samples=num_samples,
      items_to_descriptions=items_to_descriptions,
      num_classes=num_classes,
      labels_to_names=labels_to_names)

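Building the model

Many model definitions are provided in models/slim/nets/.

For a custom model, write it following the official models, put it in the same folder, and register it in models/slim/nets/nets_factory.py so that --model_name can select it.

Starting training

The official repository provides a training script. If you use the official data loading and preprocessing, start training as follows: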

cd $WORKSPACE/models/slim
CUDA_VISIBLE_DEVICES="0" python train_image_classifier.py \
  --train_dir=train_logs \
  --dataset_name=flowers \
  --dataset_split_name=train \
  --dataset_dir=../../data/flowers \
  --model_name=inception_resnet_v2 \
  --checkpoint_path=../../checkpoints/inception_resnet_v2_2016_08_30.ckpt \
  --checkpoint_exclude_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \
  --trainable_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \
  --max_number_of_steps=1000 \
  --batch_size=32 \
  --learning_rate=0.01 \
  --learning_rate_decay_type=fixed \
  --save_interval_secs=60 \
  --save_summaries_secs=60 \
  --log_every_n_steps=10 \
  --optimizer=rmsprop \
  --weight_decay=0.00004

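To train from scratch instead of fine-tuning, remove --checkpoint_path, --checkpoint_exclude_scopes and --trainable_scopes.

To fine-tune all layers, remove --trainable_scopes. Keep --checkpoint_exclude_scopes while --checkpoint_path points to the ImageNet checkpoint: its Logits and AuxLogits layers were trained for the ImageNet classes, so their variables have different shapes and cannot be restored into the new 5-class layers.

To train on the CPU only, add --clone_on_cpu=True.

The remaining flags can be removed to use their default values, or adjusted as needed.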
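These variants differ only in the checkpoint-related flags. A sketch that assembles them (the `mode` names and the helper `checkpoint_flags` are hypothetical; the flag names and paths are the ones from the command above). It keeps --checkpoint_exclude_scopes whenever the ImageNet checkpoint is restored, because the pretrained Logits and AuxLogits variables are shaped for a different number of classes than the new 5-class layers:

```python
CKPT = '../../checkpoints/inception_resnet_v2_2016_08_30.ckpt'
NEW_SCOPES = 'InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits'


def checkpoint_flags(mode, checkpoint_path=CKPT, new_scopes=NEW_SCOPES):
    """Return the checkpoint-related flags for train_image_classifier.py.

    mode: 'scratch' -> train every layer from random initialization
          'top'     -> restore pretrained weights, train only the new head
          'all'     -> restore pretrained weights, train every layer
    """
    if mode == 'scratch':
        return []
    if mode not in ('top', 'all'):
        raise ValueError('unknown mode: %s' % mode)
    # The pretrained head has the wrong number of classes, so it is never restored.
    flags = ['--checkpoint_path=' + checkpoint_path,
             '--checkpoint_exclude_scopes=' + new_scopes]
    if mode == 'top':
        # Only the freshly initialized head receives gradient updates.
        flags.append('--trainable_scopes=' + new_scopes)
    return flags
```

The result can be appended to the remaining arguments, for example `['python', 'train_image_classifier.py', '--train_dir=train_logs'] + checkpoint_flags('top')`, and passed to `subprocess.call`.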

To use your own data, modify models/slim/train_image_classifier.py as follows:

from datasets import dataset_factory

Change to:

from datasets import dataset_classification

dataset = dataset_factory.get_dataset(
  FLAGS.dataset_name,FLAGS.dataset_split_name,FLAGS.dataset_dir)

Change to:

dataset = dataset_classification.get_dataset(
  FLAGS.dataset_dir,FLAGS.num_samples,FLAGS.num_classes,FLAGS.labels_to_names_path)

tf.app.flags.DEFINE_string(
  'dataset_dir',None,'The directory where the dataset files are stored.')

After this flag definition, add:

tf.app.flags.DEFINE_integer(
  'num_samples',3320,'Number of samples.')

tf.app.flags.DEFINE_integer(
  'num_classes',5,'Number of classes.')

tf.app.flags.DEFINE_string(
  'labels_to_names_path', None, 'Label names file path.')

Then start training with:

cd $WORKSPACE/models/slim
python train_image_classifier.py \
  --train_dir=train_logs \
  --dataset_dir=../../data/train \
  --num_samples=3320 \
  --num_classes=5 \
  --labels_to_names_path=../../data/labels.txt \
  --model_name=inception_resnet_v2 \
  --checkpoint_path=../../checkpoints/inception_resnet_v2_2016_08_30.ckpt \
  --checkpoint_exclude_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits

Visualizing the logs

You can visualize the training logs while training is running, for example to follow the loss curve:

tensorboard --logdir train_logs/

Validation

The official repository provides an evaluation script:

python eval_image_classifier.py \
  --checkpoint_path=train_logs \
  --eval_dir=eval_logs \
  --dataset_name=flowers \
  --dataset_split_name=validation \
  --dataset_dir=../../data/flowers \
  --model_name=inception_resnet_v2

Likewise, to use your own dataset, modify models/slim/eval_image_classifier.py:

from datasets import dataset_factory

Change to:

from datasets import dataset_classification

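dataset = dataset_factory.get_dataset(
  FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

Change to:

dataset = dataset_classification.get_dataset(
  FLAGS.dataset_dir, FLAGS.num_samples, FLAGS.num_classes, FLAGS.labels_to_names_path)

After the dataset_dir flag definition, add the same three flags as in the training script, with num_samples defaulting to the size of the validation set:

tf.app.flags.DEFINE_integer(
  'num_samples', 350, 'Number of samples.')

tf.app.flags.DEFINE_integer(
  'num_classes', 5, 'Number of classes.')

tf.app.flags.DEFINE_string(
  'labels_to_names_path', None, 'Label names file path.')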

Then run validation with:

python eval_image_classifier.py \
  --checkpoint_path=train_logs \
  --eval_dir=eval_logs \
  --dataset_dir=../../data/val \
  --num_samples=350 \
  --num_classes=5 \
  --model_name=inception_resnet_v2

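You can run validation while training is in progress; in that case, run it on a different GPU or limit how much GPU memory each process uses.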
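One way to do this from Python is to start the evaluation script as a separate process that only sees a second GPU through CUDA_VISIBLE_DEVICES, the same variable used in the training command. A sketch, assuming training runs on GPU 0 and a GPU 1 exists; the helper names are illustrative and the arguments mirror the validation command above:

```python
import os
import subprocess


def eval_command(checkpoint_dir='train_logs', dataset_dir='../../data/val',
                 num_samples=350, num_classes=5,
                 model_name='inception_resnet_v2'):
    """Argument list for eval_image_classifier.py, mirroring the command above."""
    return ['python', 'eval_image_classifier.py',
            '--checkpoint_path=' + checkpoint_dir,
            '--eval_dir=eval_logs',
            '--dataset_dir=' + dataset_dir,
            '--num_samples=%d' % num_samples,
            '--num_classes=%d' % num_classes,
            '--model_name=' + model_name]


def env_for_gpu(gpu_id):
    """Copy of the current environment that exposes a single GPU to the child process."""
    env = dict(os.environ)
    env['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
    return env


def launch_eval_on_gpu(gpu_id=1, **kwargs):
    """Start evaluation in the background; training keeps its own GPU."""
    return subprocess.Popen(eval_command(**kwargs), env=env_for_gpu(gpu_id))
```

For example, run `proc = launch_eval_on_gpu(1)` from $WORKSPACE/models/slim. eval_image_classifier.py evaluates the most recent checkpoint in train_logs once per run, so rerun it periodically to follow training.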

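The evaluation logs can be visualized in the same way. If TensorBoard is already serving the training logs, use a different port, for example: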

tensorboard --logdir eval_logs/ --port 6007

Testing

Based on models/slim/eval_image_classifier.py, you can write a script, models/slim/test_image_classifier.py, that reads an image and runs inference on it with the trained model:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import math
import tensorflow as tf

from nets import nets_factory
from preprocessing import preprocessing_factory

slim = tf.contrib.slim

tf.app.flags.DEFINE_string(
  'master','','The address of the TensorFlow master to use.')

tf.app.flags.DEFINE_string(
  'checkpoint_path','/tmp/tfmodel/','The directory where the model was written to or an absolute path to a '
  'checkpoint file.')

tf.app.flags.DEFINE_string(
  'test_path', None, 'Test image path.')

tf.app.flags.DEFINE_integer(
  'num_classes', 5, 'Number of classes.')

tf.app.flags.DEFINE_integer(
  'labels_offset', 0, 'An offset for the labels in the dataset. This flag is primarily used to '
  'evaluate the VGG and ResNet architectures which do not use a background '
  'class for the ImageNet dataset.')

tf.app.flags.DEFINE_string(
  'model_name', 'inception_v3', 'The name of the architecture to evaluate.')

tf.app.flags.DEFINE_string(
  'preprocessing_name', None, 'The name of the preprocessing to use. If left '
  'as `None`, then the model_name flag is used.')

tf.app.flags.DEFINE_integer(
  'test_image_size', None, 'Eval image size')

FLAGS = tf.app.flags.FLAGS


def main(_):
  if not FLAGS.test_path:
    raise ValueError('You must supply the test image with --test_path')

  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    tf_global_step = slim.get_or_create_global_step()

    ####################
    # Select the model #
    ####################
    network_fn = nets_factory.get_network_fn(
      FLAGS.model_name,num_classes=(FLAGS.num_classes - FLAGS.labels_offset),is_training=False)

    #####################################
    # Select the preprocessing function #
    #####################################
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
      preprocessing_name,is_training=False)

    test_image_size = FLAGS.test_image_size or network_fn.default_image_size

    if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
      checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
    else:
      checkpoint_path = FLAGS.checkpoint_path

    # The ops below are added to the graph created above; the session runs that graph.
    with tf.Session() as sess:
      with tf.gfile.FastGFile(FLAGS.test_path, 'rb') as f:
        image_data = f.read()
      image = tf.image.decode_jpeg(image_data, channels=3)
      processed_image = image_preprocessing_fn(image, test_image_size, test_image_size)
      processed_images = tf.expand_dims(processed_image, 0)  # add a batch dimension
      logits, _ = network_fn(processed_images)
      predictions = tf.argmax(logits, 1)
      # Restore the trained weights, then predict the class index of the single image.
      saver = tf.train.Saver()
      saver.restore(sess, checkpoint_path)
      predicted_class = sess.run(predictions)[0]
      print('{} {}'.format(FLAGS.test_path, predicted_class))

if __name__ == '__main__':
  tf.app.run()

Run the test with:

python test_image_classifier.py \
  --checkpoint_path=train_logs/ \
  --test_path=../../data/flower_photos/tulips/6948239566_0ac0a124ee_n.jpg \
  --num_classes=5 \
  --model_name=inception_resnet_v2
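The script prints only the predicted class index. To turn it into a readable name, look the index up in labels.txt, where line i holds the name of class i. A minimal sketch with illustrative helper names that parses one output line of the form `<image path> <index>`:

```python
def load_label_names(labels_path):
    """Return {class index: name}; line i of labels.txt is class i."""
    with open(labels_path) as fd:
        return {i: line.strip() for i, line in enumerate(fd) if line.strip()}


def name_for_output_line(output_line, label_names):
    """Split '<image path> <index>' and return (image path, class name)."""
    image_path, index = output_line.rsplit(None, 1)
    return image_path, label_names[int(index)]
```

For example, `name_for_output_line(line, load_label_names('../../data/labels.txt'))` maps a printed index of 4 to tulips.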
