深度學習中基於tensorflow_slim進行復雜模型訓練二之tensorflow_slim的使用
上篇部落格主要介紹了tensorflow_slim的基本模組,本篇主要介紹一下如何使用該模組訓練自己的模型。主要分為資料轉化,資料讀取,資料預處理,模型選擇,訓練引數設定,構建pb檔案,固化pb檔案中的引數幾部分。
一、資料轉化:
主要目的是將圖片轉化為TFrecords檔案,該部分屬於資料的預處理階段,可以參考datasets中的download_and_conver_flower中的run函式實現。具體關於如何使用將會在後續介紹。
二、資料讀取
該部分主要是在datasets中新建一個檔案並將其命名為自己的名字,例如命名為emotion.py,然後將flowers.py中的內容複製到新建的檔案中,並對以下部分進行修改:
1. _FILE_PATTERN = 'emotion_%s_*.tfrecord' 表示tfrecord檔名的格式
2. SPLITS_TO_SIZES = {'train': 18534, 'validation': 8331}表示用於訓練和測試的資料個數
3. _NUM_CLASSES = 5,訓練資料的類數,涉及到網路模型最後一層的輸出。
最後需要在dataset_factory中增加自己新建的資料對映。
datasets_map = {
'emotion': emotion,
}
三、資料增強
該過程主要是對讀取的資料進行資料增強,可以有兩種方式:1. 採用現有的增強模式(因為資料增強的大部分操作都是一樣的),2. 構建自己的增強方式(可以使模型訓練的時候傳入的引數較統一)。
對於第二種方式依然需要構建新的資料夾,然後複製一個內容進行修改或者完全自己書寫。本次採用的是複製cifarnet_preprocessing.py的內容進行修改得到的。具體修改的地方如下:將
distorted_image = tf.random_crop(image, [output_height, output_width, 3]) 改為
distorted_image = tf.image.resize_images(image, [output_height, output_width], method=1)
主要是為了避免需要的圖片比還未裁剪的小導致無法進行裁剪的錯誤
然後在preprocessing.py中增加新的對映:
preprocessing_fn_map = {
'emotion': emotion_preprocessing,
}
四、模型選擇
在nets中選擇出自己需要使用的模型,並下載對應訓練好的模型.ckpt檔案,具體的下載地址可以參考README.md檔案(以inception_v3模型為例)[inception_v3_2016_08_28.tar.gz](http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz)
五、訓練引數設定
準備工作完成後就是使用train_image_classifier.py對自己的資料進行訓練,該部分涉及到較多的引數,具體設定如下:
1. tf.app.flags.DEFINE_string( 'dataset_name', 'emotion', 'The name of the dataset to load.')
表示資料的名字,在讀取資料和資料增強的時候該值相當於是map中的key,根據該值找到對應的讀取和增強的指令碼。
tf.app.flags.DEFINE_string('dataset_split_name', 'train', 'The name of the train/test split.')
表示資料的作用,用來train還是validation,該值主要是產生emotion_train*.tfrecord的形式存於data_sources中,便於在後面的讀取資料時使用
if '*' in data_sources or '?' in data_sources or '[' in data_sources: data_files = gfile.Glob(data_sources)
的方式對資料進行讀取。
tf.app.flags.DEFINE_string( 'dataset_dir', "data_to_tfrecord", 'The directory where the dataset files are stored.')
表示tfrecord資料的儲存路徑,構建data結構讀取資料時使用
tf.app.flags.DEFINE_integer( 'labels_offset', 0, 'An offset for the labels in the dataset. This flag is primarily used to evaluate the VGG and ResNet architectures which do not use a background class for the ImageNet dataset.') 表示標籤的偏移量,即預設標籤是從0開始,假如偏移2,那麼標籤就會從2開始,一般情況下選擇預設的值即可。
tf.app.flags.DEFINE_string('model_name', 'inception_v3', 'The name of the architecture to train.')用於進行二次訓練的模型名字,主要依賴於net_factory中的map是如何寫的。
tf.app.flags.DEFINE_string( 'preprocessing_name', 'emotion', 'The name of the preprocessing to use. If left as `None`, then the model_name flag is used.') 表示採用的預處理方式,主要依賴於preprocessing_factory中的map
tf.app.flags.DEFINE_integer('batch_size', 32, 'The number of samples in each batch.') 在將資料進行批處理時每批資料的多少
tf.app.flags.DEFINE_integer( 'train_image_size', 299, 'Train image size') 模型輸入的圖片大小
tf.app.flags.DEFINE_integer('max_number_of_steps', 50000, 'The maximum number of training steps.') 表示訓練的步數
tf.app.flags.DEFINE_integer( 'log_every_n_steps', 10, 'The frequency with which logs are print.') log的輸出頻率,即每執行多少步輸出一個log
tf.app.flags.DEFINE_integer( 'save_summaries_secs', 100, 'The frequency with which summaries are saved, in seconds.')
表示儲存summaries的頻率
tf.app.flags.DEFINE_integer('save_interval_secs', 600, 'The frequency with which the model is saved, in seconds.') 表示儲存模型的頻率
tf.app.flags.DEFINE_float( 'weight_decay', 0.00004, 'The weight decay on the model weights.') 表示為了避免過擬合採用正則化的係數
tf.app.flags.DEFINE_string( 'train_dir', 'train_result', 'Directory where checkpoints and event logs are written to.')表示訓練引數儲存的地方
tf.app.flags.DEFINE_string( 'checkpoint_path', "pre_trained_check/inception_v3_2016_08_28/inception_v3.ckpt", 'The path to a checkpoint from which to fine-tune.') 表示提前處理好的模型引數儲存的地方
tf.app.flags.DEFINE_string('checkpoint_exclude_scopes', "InceptionV3/Logits,InceptionV3/AuxLogits",'Comma-separated list of scopes of variables to exclude when restoring from a checkpoint.')模型中不用恢復的節點,一般均為模型的輸出層,因為輸出層需要結合自己實際的類進行訓練確定資料的輸出大小,當該值為空時,則表示所有的變數均恢復。
tf.app.flags.DEFINE_string('trainable_scopes', None, 'Comma-separated list of scopes to filter the set of variables to train. By default, None would train all the variables.') 表示再次訓練的節點,None表示所有的都參與訓練。
tf.app.flags.DEFINE_string( 'learning_rate_decay_type', 'exponential', 'Specifies how the learning rate is decayed. One of "fixed", "exponential", or "polynomial"') 表示學習率衰減的方式。
對於該模組在使用中涉及到的其他引數均使用預設的即可。
在使用指令碼時有時候會報出部分操作無法在GPU上執行的錯誤,此時train的上面增加config = tf.ConfigProto(allow_soft_placement=True)表示當無法採用GPU計算時使用cpu進行。並將該引數傳遞給train。
五、構建pb檔案
此時直接使用export_interence_graph.py可以將模型結構變成.pb的,涉及的引數如下:
tf.app.flags.DEFINE_string( 'model_name', 'inception_v3', 'The name of the architecture to save.') 表示要呼叫的模型結構
tf.app.flags.DEFINE_boolean( 'is_training', False, 'Whether to save out a training-focused version of the model.') 表示在模型中的引數是否用來進行訓練
tf.app.flags.DEFINE_integer( 'image_size', 299, 'The image size to use, otherwise use the model default_image_size.')定義一個輸入佔位符的二三維大小
tf.app.flags.DEFINE_string('dataset_name', 'emotion', 'The name of the dataset to use with the model.') 主要根據傳入的名字確定一個對應的資料集,確定其num_class的值用於構建模型結構
tf.app.flags.DEFINE_integer( 'labels_offset', 0, 'An offset for the labels in the dataset. This flag is primarily used to evaluate the VGG and ResNet architectures which do not use a background class for the ImageNet dataset.')偏移量,用來構建模型的時候會用到
tf.app.flags.DEFINE_string( 'output_file', 'train_pb/motion_inception_v3_graph.pb', 'Where to save the resulting file to.')輸出的.pb檔名字和儲存地方
tf.app.flags.DEFINE_integer( 'batch_size', None,'Batch size for the exported model. Defaulted to "None" so batch size can ')定義輸入佔位符的第一維度大小。
六、在模型結構中放入自己訓練的結果並固化
該過程實現的原理是先讀入一個結構圖,然後在使用saver.restore()恢復圖中對應
引數的值,最後再儲存。
import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.training import saver as saver_lib
# 定義一些引數
input_graph = 'train_pb\\test_1.pb'
output_graph = 'train_pb\\test_2.pb'
input_checkpoint = 'train_result\\model.ckpt-20'
output_node_names = 'InceptionV3/Predictions/Reshape_1'
detection_graph = tf.Graph()
with detection_graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(input_graph, 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')
with detection_graph.as_default():
with tf.Session(graph=detection_graph) as sess:
var_list = {}
reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
try:
tensor = sess.graph.get_tensor_by_name(key + ":0")
except KeyError:
continue
var_list[key] = tensor
saver = tf.train.Saver(var_list=var_list)
saver.restore(sess, input_checkpoint)
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def,
output_node_names.split(","))
with tf.gfile.FastGFile(output_graph, mode='wb') as f:
f.write(constant_graph.SerializeToString())
至此,就完成了模型訓練和固化,然後可以根據具體需要自行進行使用。