tensorflow-deeplab-resnet: Principles and Code Walkthrough
阿新 · Published 2018-12-13
Preface: model.py and network.py build the deep network. Their style is very close to that of the Faster-RCNN_TF project and they are straightforward, so they are not covered here. This post focuses on train.py and image_reader.py; inference.py, utils.py and fine_tune.py are simple enough to skip.
1. Network structure: Network.py was slightly modified so that it prints the output shape of every layer, shown below.
Because a single GTX 1080 has limited video memory, I use a batch size of 4 and 384×384 input images; note also that the first convolution uses stride 2. The per-layer printout is highly repetitive (every bn_* layer has the same shape as the conv layer it follows, and every residual block within a stage repeats the same shapes), so it reduces to the following table:

| Stage | Blocks | Bottleneck output (branch2a/2b) | Block output (branch1/2c, relu) |
|---|---|---|---|
| conv1, bn_conv1, pool1 | stem | — | (4, 192, 192, 64) |
| res2a – res2c | 3 | (4, 96, 96, 64) | (4, 96, 96, 256) |
| res3a – res3b3 | 4 | (4, 48, 48, 128) | (4, 48, 48, 512) |
| res4a – res4b22 | 23 | (4, 48, 48, 256) | (4, 48, 48, 1024) |
| res5a – res5c | 3 | (4, 48, 48, 512) | (4, 48, 48, 2048) |
| fc1_voc12_c0 – fc1_voc12_c3 | 4 parallel branches | — | (4, 48, 48, 2) each |

The spatial resolution stops shrinking after res3: res4 and res5 replace strides with atrous (dilated) convolutions, which is the central idea of DeepLab. The four parallel fc1_voc12_c* branches are atrous convolutions with different dilation rates whose outputs are summed into the final fc1_voc12 prediction. (The run shown here uses 2 output classes rather than the 21 of PASCAL VOC.)
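The modification itself is not shown in the post; below is a minimal sketch of how such a printout could be produced, assuming (as in this repository) that the kaffe-style Network base class in network.py keeps a name → tensor mapping in net.layers:

```python
# Hedged sketch: dump each layer's output shape after the graph is built.
# Assumes net.layers is the name -> tensor dict maintained by network.py.
# Note: on Python 2 a plain dict does not preserve insertion order, so the
# lines may not come out in construction order.
import tensorflow as tf

from deeplab_resnet import DeepLabResNetModel

image_batch = tf.placeholder(tf.float32, shape=[4, 384, 384, 3])
net = DeepLabResNetModel({'data': image_batch}, is_training=False, num_classes=2)

for name, tensor in net.layers.items():
    print(name, tensor.get_shape())
```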
2. train.py: a detailed look at the implementation follows.
```python
from __future__ import print_function
import argparse
from datetime import datetime
import os
import sys
import time

import tensorflow as tf
import numpy as np

from deeplab_resnet import DeepLabResNetModel, ImageReader, decode_labels, inv_preprocess, prepare_label

IMG_MEAN = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32)

BATCH_SIZE = 10
DATA_DIRECTORY = '/home/VOCdevkit'
DATA_LIST_PATH = './dataset/train.txt'
IGNORE_LABEL = 255
INPUT_SIZE = '321,321'
LEARNING_RATE = 2.5e-4
MOMENTUM = 0.9
NUM_CLASSES = 21
NUM_STEPS = 20001
POWER = 0.9
RANDOM_SEED = 1234
RESTORE_FROM = './deeplab_resnet.ckpt'
SAVE_NUM_IMAGES = 2
SAVE_PRED_EVERY = 1000
SNAPSHOT_DIR = './snapshots/'
WEIGHT_DECAY = 0.0005


def get_arguments():
    """Parse all the arguments provided from the CLI.

    Returns:
      A list of parsed arguments.
    """
    parser = argparse.ArgumentParser(description="DeepLab-ResNet Network")
    parser.add_argument("--batch-size", type=int, default=BATCH_SIZE,
                        help="Number of images sent to the network in one step.")
    parser.add_argument("--data-dir", type=str, default=DATA_DIRECTORY,
                        help="Path to the directory containing the PASCAL VOC dataset.")
    parser.add_argument("--data-list", type=str, default=DATA_LIST_PATH,
                        help="Path to the file listing the images in the dataset.")
    parser.add_argument("--ignore-label", type=int, default=IGNORE_LABEL,
                        help="The index of the label to ignore during the training.")
    parser.add_argument("--input-size", type=str, default=INPUT_SIZE,
                        help="Comma-separated string with height and width of images.")
    parser.add_argument("--is-training", action="store_true",
                        help="Whether to update the running means and variances during the training.")
    parser.add_argument("--learning-rate", type=float, default=LEARNING_RATE,
                        help="Base learning rate for training with polynomial decay.")
    parser.add_argument("--momentum", type=float, default=MOMENTUM,
                        help="Momentum component of the optimiser.")
    parser.add_argument("--not-restore-last", action="store_true",
                        help="Whether to not restore last (FC) layers.")
    parser.add_argument("--num-classes", type=int, default=NUM_CLASSES,
                        help="Number of classes to predict (including background).")
    parser.add_argument("--num-steps", type=int, default=NUM_STEPS,
                        help="Number of training steps.")
    parser.add_argument("--power", type=float, default=POWER,
                        help="Decay parameter to compute the learning rate.")
    parser.add_argument("--random-mirror", action="store_true",
                        help="Whether to randomly mirror the inputs during the training.")
    parser.add_argument("--random-scale", action="store_true",
                        help="Whether to randomly scale the inputs during the training.")
    parser.add_argument("--random-seed", type=int, default=RANDOM_SEED,
                        help="Random seed to have reproducible results.")
    parser.add_argument("--restore-from", type=str, default=RESTORE_FROM,
                        help="Where to restore model parameters from.")
    parser.add_argument("--save-num-images", type=int, default=SAVE_NUM_IMAGES,
                        help="How many images to save.")
    parser.add_argument("--save-pred-every", type=int, default=SAVE_PRED_EVERY,
                        help="Save summaries and checkpoint every SAVE_PRED_EVERY steps.")
    parser.add_argument("--snapshot-dir", type=str, default=SNAPSHOT_DIR,
                        help="Where to save snapshots of the model.")
    parser.add_argument("--weight-decay", type=float, default=WEIGHT_DECAY,
                        help="Regularisation parameter for L2-loss.")
    return parser.parse_args()


# Helper that saves a checkpoint.
def save(saver, sess, logdir, step):
    '''Save weights.

    Args:
      saver: TensorFlow Saver object.
      sess: TensorFlow session.
      logdir: path to the snapshots directory.
      step: current training step.
    '''
    # Checkpoint file name and directory.
    model_name = 'model.ckpt'
    checkpoint_path = os.path.join(logdir, model_name)

    if not os.path.exists(logdir):
        os.makedirs(logdir)
    saver.save(sess, checkpoint_path, global_step=step)
    print('The checkpoint has been created.')


# Reload helper: restore network parameters from a .ckpt file and resume training.
def load(saver, sess, ckpt_path):
    '''Load trained weights.

    Args:
      saver: TensorFlow Saver object.
      sess: TensorFlow session.
      ckpt_path: path to checkpoint file with parameters.
    '''
    saver.restore(sess, ckpt_path)
    print("Restored model parameters from {}".format(ckpt_path))


def main():
    """Create the model and start the training."""
    # Parse the arguments passed on the command line.
    # If the script is not launched from the CLI, the same effect can be had
    # by editing the default values above.
    args = get_arguments()

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    tf.set_random_seed(args.random_seed)

    # Data is fed to the network through a TensorFlow queue, initialised below.
    # First create the queue coordinator for the reader threads.
    coord = tf.train.Coordinator()

    # Load reader.
    with tf.name_scope("create_inputs"):
        reader = ImageReader(
            args.data_dir,
            args.data_list,
            input_size,
            args.random_scale,
            args.random_mirror,
            args.ignore_label,
            IMG_MEAN,
            coord)
        # The queue outputs image_batch and label_batch.
        image_batch, label_batch = reader.dequeue(args.batch_size)

    # Create network.
    net = DeepLabResNetModel({'data': image_batch}, is_training=args.is_training, num_classes=args.num_classes)
    # For a small batch size, it is better to keep
    # the statistics of the BN layers (running means and variances)
    # frozen, and to not update the values provided by the pre-trained model.
    # If is_training=True, the statistics will be updated during the training.
    # Note that is_training=False still updates BN parameters gamma (scale) and beta (offset)
    # if they are presented in var_list of the optimiser definition.

    # Predictions.
    raw_output = net.layers['fc1_voc12']

    # Decide which parameters are trained and which are only restored.
    # Which variables to load. Running means and variances are not trainable,
    # thus all_variables() should be restored.
    restore_var = [v for v in tf.global_variables() if 'fc' not in v.name or not args.not_restore_last]
    all_trainable = [v for v in tf.trainable_variables() if 'beta' not in v.name and 'gamma' not in v.name]
    fc_trainable = [v for v in all_trainable if 'fc' in v.name]
    # Among the trainable parameters, convolutional weights use learning rate lr,
    # the classifier weights use 10*lr and the classifier biases use 20*lr.
    conv_trainable = [v for v in all_trainable if 'fc' not in v.name]  # lr * 1.0
    fc_w_trainable = [v for v in fc_trainable if 'weights' in v.name]  # lr * 10.0
    fc_b_trainable = [v for v in fc_trainable if 'biases' in v.name]   # lr * 20.0
    assert(len(all_trainable) == len(fc_trainable) + len(conv_trainable))
    assert(len(fc_trainable) == len(fc_w_trainable) + len(fc_b_trainable))

    # Predictions: ignoring all predictions with labels greater or equal than n_classes.
    # Reshape the network output to [-1, args.num_classes].
    raw_prediction = tf.reshape(raw_output, [-1, args.num_classes])
    # Because sparse_softmax_cross_entropy_with_logits is used, the labels are
    # reshaped to [batch_size, h, w], i.e. the channel dimension is dropped.
    # With softmax_cross_entropy_with_logits you would set one_hot=True instead,
    # so that the labels become [batch_size, h, w, num_classes].
    label_proc = prepare_label(label_batch, tf.stack(raw_output.get_shape()[1:3]), num_classes=args.num_classes, one_hot=False)  # [batch_size, h, w]
    # Flatten to the same shape as raw_prediction.
    raw_gt = tf.reshape(label_proc, [-1,])
    # Drop the label values that are >= num_classes (e.g. the ignore label).
    indices = tf.squeeze(tf.where(tf.less_equal(raw_gt, args.num_classes - 1)), 1)
    gt = tf.cast(tf.gather(raw_gt, indices), tf.int32)
    prediction = tf.gather(raw_prediction, indices)

    # Pixel-wise softmax loss.
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=prediction, labels=gt)
    l2_losses = [args.weight_decay * tf.nn.l2_loss(v) for v in tf.trainable_variables() if 'weights' in v.name]
    reduced_loss = tf.reduce_mean(loss) + tf.add_n(l2_losses)

    # Processed predictions: for visualisation.
    raw_output_up = tf.image.resize_bilinear(raw_output, tf.shape(image_batch)[1:3,])
    raw_output_up = tf.argmax(raw_output_up, dimension=3)
    pred = tf.expand_dims(raw_output_up, dim=3)

    # Image summary.
    images_summary = tf.py_func(inv_preprocess, [image_batch, args.save_num_images, IMG_MEAN], tf.uint8)
    labels_summary = tf.py_func(decode_labels, [label_batch, args.save_num_images, args.num_classes], tf.uint8)
    preds_summary = tf.py_func(decode_labels, [pred, args.save_num_images, args.num_classes], tf.uint8)

    total_summary = tf.summary.image('images',
                                     tf.concat(axis=2, values=[images_summary, labels_summary, preds_summary]),
                                     max_outputs=args.save_num_images)  # Concatenate row-wise.
    summary_writer = tf.summary.FileWriter(args.snapshot_dir, graph=tf.get_default_graph())

    # Define loss and optimisation parameters.
    base_lr = tf.constant(args.learning_rate)
    step_ph = tf.placeholder(dtype=tf.float32, shape=())
    # Polynomially decaying learning rate.
    # (I found this fragment very instructive: previously I would just hand
    # tf.train.MomentumOptimizer a fixed initial value and be done with it.
    # Here the author gives different parameter groups different learning
    # rates, designs a decay policy, and shows how to wire it into the graph.)
    learning_rate = tf.scalar_mul(base_lr, tf.pow((1 - step_ph / args.num_steps), args.power))

    opt_conv = tf.train.MomentumOptimizer(learning_rate, args.momentum)
    opt_fc_w = tf.train.MomentumOptimizer(learning_rate * 10.0, args.momentum)
    opt_fc_b = tf.train.MomentumOptimizer(learning_rate * 20.0, args.momentum)

    grads = tf.gradients(reduced_loss, conv_trainable + fc_w_trainable + fc_b_trainable)
    grads_conv = grads[:len(conv_trainable)]
    grads_fc_w = grads[len(conv_trainable) : (len(conv_trainable) + len(fc_w_trainable))]
    grads_fc_b = grads[(len(conv_trainable) + len(fc_w_trainable)):]

    train_op_conv = opt_conv.apply_gradients(zip(grads_conv, conv_trainable))
    train_op_fc_w = opt_fc_w.apply_gradients(zip(grads_fc_w, fc_w_trainable))
    train_op_fc_b = opt_fc_b.apply_gradients(zip(grads_fc_b, fc_b_trainable))

    train_op = tf.group(train_op_conv, train_op_fc_w, train_op_fc_b)

    # Set up tf session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    init = tf.global_variables_initializer()
    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)

    # Load variables if the checkpoint is provided.
    if args.restore_from is not None:
        loader = tf.train.Saver(var_list=restore_var)
        load(loader, sess, args.restore_from)

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    # Iterate over training steps.
    for step in range(args.num_steps):
        start_time = time.time()
        feed_dict = { step_ph : step }

        if step % args.save_pred_every == 0:
            loss_value, images, labels, preds, summary, _ = sess.run([reduced_loss, image_batch, label_batch, pred, total_summary, train_op], feed_dict=feed_dict)
            summary_writer.add_summary(summary, step)
            save(saver, sess, args.snapshot_dir, step)
        else:
            loss_value, _ = sess.run([reduced_loss, train_op], feed_dict=feed_dict)
        duration = time.time() - start_time
        print('step {:d} \t loss = {:.3f}, ({:.3f} sec/step)'.format(step, loss_value, duration))
    coord.request_stop()
    coord.join(threads)


# Entry point.
if __name__ == '__main__':
    main()
```
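To make the learning-rate policy concrete, here is a small standalone sketch (plain Python, not part of train.py) of the polynomial decay lr(step) = base_lr · (1 − step/num_steps)^power together with the per-group scaling described above:

```python
# Standalone sketch of the "poly" learning-rate policy used in train.py:
#   lr(step) = base_lr * (1 - step / num_steps) ** power
# with the conv / fc-weight / fc-bias groups scaling it by 1x / 10x / 20x.
base_lr, num_steps, power = 2.5e-4, 20001, 0.9

def poly_lr(step):
    return base_lr * (1.0 - float(step) / num_steps) ** power

for step in (0, 5000, 10000, 15000, 20000):
    lr = poly_lr(step)
    print('step %5d  conv: %.2e  fc_w: %.2e  fc_b: %.2e' % (step, lr, 10 * lr, 20 * lr))
```

The schedule decays smoothly from base_lr towards zero as step approaches num_steps, which is the same "poly" policy used in the original DeepLab Caffe prototxt.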
3. image_reader.py: its main job is to initialise a TensorFlow queue that feeds data to the network.
```python
import os

import numpy as np
import tensorflow as tf


def image_scaling(img, label):
    """
    Randomly scales the images between 0.5 to 1.5 times the original size.

    Args:
      img: Training image to scale.
      label: Segmentation mask to scale.
    """
    scale = tf.random_uniform([1], minval=0.5, maxval=1.5, dtype=tf.float32, seed=None)
    h_new = tf.to_int32(tf.multiply(tf.to_float(tf.shape(img)[0]), scale))
    w_new = tf.to_int32(tf.multiply(tf.to_float(tf.shape(img)[1]), scale))
    new_shape = tf.squeeze(tf.stack([h_new, w_new]), squeeze_dims=[1])
    img = tf.image.resize_images(img, new_shape)
    label = tf.image.resize_nearest_neighbor(tf.expand_dims(label, 0), new_shape)
    label = tf.squeeze(label, squeeze_dims=[0])
    return img, label


def image_mirroring(img, label):
    """
    Randomly mirrors the images.

    Args:
      img: Training image to mirror.
      label: Segmentation mask to mirror.
    """
    distort_left_right_random = tf.random_uniform([1], 0, 1.0, dtype=tf.float32)[0]
    mirror = tf.less(tf.stack([1.0, distort_left_right_random, 1.0]), 0.5)
    mirror = tf.boolean_mask([0, 1, 2], mirror)
    img = tf.reverse(img, mirror)
    label = tf.reverse(label, mirror)
    return img, label


def random_crop_and_pad_image_and_labels(image, label, crop_h, crop_w, ignore_label=255):
    """
    Randomly crop and pads the input images.

    Args:
      image: Training image to crop/ pad.
      label: Segmentation mask to crop/ pad.
      crop_h: Height of cropped segment.
      crop_w: Width of cropped segment.
      ignore_label: Label to ignore during the training.
    """
    label = tf.cast(label, dtype=tf.float32)
    label = label - ignore_label  # Needs to be subtracted and later added due to 0 padding.
    combined = tf.concat(axis=2, values=[image, label])
    image_shape = tf.shape(image)
    combined_pad = tf.image.pad_to_bounding_box(combined, 0, 0,
                                                tf.maximum(crop_h, image_shape[0]),
                                                tf.maximum(crop_w, image_shape[1]))

    last_image_dim = tf.shape(image)[-1]
    last_label_dim = tf.shape(label)[-1]
    combined_crop = tf.random_crop(combined_pad, [crop_h, crop_w, 4])
    img_crop = combined_crop[:, :, :last_image_dim]
    label_crop = combined_crop[:, :, last_image_dim:]
    label_crop = label_crop + ignore_label
    label_crop = tf.cast(label_crop, dtype=tf.uint8)

    # Set static shape so that tensorflow knows shape at compile time.
    img_crop.set_shape((crop_h, crop_w, 3))
    label_crop.set_shape((crop_h, crop_w, 1))
    return img_crop, label_crop


def read_labeled_image_list(data_dir, data_list):
    """Reads txt file containing paths to images and ground truth masks.

    Args:
      data_dir: path to the directory with images and masks.
      data_list: path to the file with lines of the form '/path/to/image /path/to/mask'.

    Returns:
      Two lists with all file names for images and masks, respectively.
    """
    f = open(data_list, 'r')
    images = []
    masks = []
    for line in f:
        try:
            image, mask = line.strip("\n").split(' ')
        except ValueError:  # Adhoc for test.
            image = mask = line.strip("\n")
        images.append(data_dir + image)
        masks.append(data_dir + mask)
    return images, masks


def read_images_from_disk(input_queue, input_size, random_scale, random_mirror, ignore_label, img_mean):  # optional pre-processing arguments
    """Read one image and its corresponding mask with optional pre-processing.

    Args:
      input_queue: tf queue with paths to the image and its mask.
      input_size: a tuple with (height, width) values. If not given, return images of original size.
      random_scale: whether to randomly scale the images prior to random crop.
      random_mirror: whether to randomly mirror the images prior to random crop.
      ignore_label: index of label to ignore during the training.
      img_mean: vector of mean colour values.
    Returns:
      Two tensors: the decoded image and its mask.
    """
    img_contents = tf.read_file(input_queue[0])
    label_contents = tf.read_file(input_queue[1])

    img = tf.image.decode_jpeg(img_contents, channels=3)
    img_r, img_g, img_b = tf.split(axis=2, num_or_size_splits=3, value=img)
    img = tf.cast(tf.concat(axis=2, values=[img_b, img_g, img_r]), dtype=tf.float32)
    # Extract mean.
    img -= img_mean

    label = tf.image.decode_png(label_contents, channels=1)

    if input_size is not None:
        h, w = input_size

        # Randomly scale the images and labels.
        if random_scale:
            img, label = image_scaling(img, label)

        # Randomly mirror the images and labels.
        if random_mirror:
            img, label = image_mirroring(img, label)

        # Randomly crops the images and labels.
        img, label = random_crop_and_pad_image_and_labels(img, label, h, w, ignore_label)

    return img, label


class ImageReader(object):
    '''Generic ImageReader which reads images and corresponding segmentation
       masks from the disk, and enqueues them into a TensorFlow queue.
    '''

    def __init__(self, data_dir, data_list, input_size,
                 random_scale, random_mirror, ignore_label, img_mean, coord):
        '''Initialise an ImageReader.

        Args:
          data_dir: path to the directory with images and masks.
          data_list: path to the file with lines of the form '/path/to/image /path/to/mask'.
          input_size: a tuple with (height, width) values, to which all the images will be resized.
          random_scale: whether to randomly scale the images prior to random crop.
          random_mirror: whether to randomly mirror the images prior to random crop.
          ignore_label: index of label to ignore during the training.
          img_mean: vector of mean colour values.
          coord: TensorFlow queue coordinator.
        '''
        self.data_dir = data_dir
        self.data_list = data_list
        self.input_size = input_size
        self.coord = coord

        # self.image_list and self.label_list are Python lists holding the
        # paths of all images and all labels.
        self.image_list, self.label_list = read_labeled_image_list(self.data_dir, self.data_list)
        # Convert the lists to tensors so they can flow through the graph.
        self.images = tf.convert_to_tensor(self.image_list, dtype=tf.string)
        self.labels = tf.convert_to_tensor(self.label_list, dtype=tf.string)
        # Note that the tensors must be passed in as a list.
        # The producer queue yields one (image path, label path) pair at a time.
        self.queue = tf.train.slice_input_producer([self.images, self.labels],
                                                   shuffle=input_size is not None)  # not shuffling if it is val
        # Read the image pair from the produced paths.
        self.image, self.label = read_images_from_disk(self.queue, self.input_size, random_scale, random_mirror, ignore_label, img_mean)

    def dequeue(self, num_elements):
        '''Pack images and labels into a batch.

        Args:
          num_elements: the batch size.

        Returns:
          Two tensors of size (batch_size, h, w, {3, 1}) for images and masks.'''
        image_batch, label_batch = tf.train.batch([self.image, self.label], num_elements)
        return image_batch, label_batch
```
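To see the reader in action outside of train.py, here is a minimal usage sketch; the dataset paths are placeholders, and the constructor arguments mirror the call made in train.py:

```python
# Hedged usage sketch of ImageReader; dataset paths below are placeholders.
import numpy as np
import tensorflow as tf

from deeplab_resnet import ImageReader

IMG_MEAN = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32)

coord = tf.train.Coordinator()
reader = ImageReader('/home/VOCdevkit', './dataset/train.txt', (321, 321),
                     True, True, 255, IMG_MEAN, coord)  # random_scale, random_mirror, ignore_label
image_batch, label_batch = reader.dequeue(10)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # The queue runners must be started, otherwise sess.run blocks forever.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)
    imgs, lbls = sess.run([image_batch, label_batch])
    print(imgs.shape, lbls.shape)  # (10, 321, 321, 3) (10, 321, 321, 1)
    coord.request_stop()
    coord.join(threads)
```

This also makes the division of labour clear: slice_input_producer shuffles and serves path pairs, read_images_from_disk decodes and augments a single example, and tf.train.batch assembles the final mini-batch.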