
tensorflow-deeplab-resnet: Principles and Code Walkthrough

Preface: model.py and network.py build the deep-learning network. Their style closely mirrors that of the Faster-RCNN_TF project and they are simple, so they are not covered here. This post focuses on train.py and image_reader.py; inference.py, utils.py, and fine_tune.py are straightforward enough to skip as well.

1. Network structure: Network.py was lightly modified so that it prints the output shape of every layer; the result is listed below, and a sketch of the modification follows the listing.

Because a single 1080 graphics card has limited GPU memory, I use a batch size of 4 with 384×384 input images. Note that the first convolution uses a stride of 2.
conv1 (4, 192, 192, 64)
bn_conv1 (4, 192, 192, 64)
pool1 (4, 192, 192, 64)
res2a_branch1 (4, 96, 96, 256)
bn2a_branch1 (4, 96, 96, 256)
res2a_branch2a (4, 96, 96, 64)
bn2a_branch2a (4, 96, 96, 64)
res2a_branch2b (4, 96, 96, 64)
bn2a_branch2b (4, 96, 96, 64)
res2a_branch2c (4, 96, 96, 256)
bn2a_branch2c (4, 96, 96, 256)
res2a_relu (4, 96, 96, 256)
res2b_branch2a (4, 96, 96, 64)
bn2b_branch2a (4, 96, 96, 64)
res2b_branch2b (4, 96, 96, 64)
bn2b_branch2b (4, 96, 96, 64)
res2b_branch2c (4, 96, 96, 256)
bn2b_branch2c (4, 96, 96, 256)
res2b_relu (4, 96, 96, 256)
res2c_branch2a (4, 96, 96, 64)
bn2c_branch2a (4, 96, 96, 64)
res2c_branch2b (4, 96, 96, 64)
bn2c_branch2b (4, 96, 96, 64)
res2c_branch2c (4, 96, 96, 256)
bn2c_branch2c (4, 96, 96, 256)
res2c_relu (4, 96, 96, 256)
res3a_branch1 (4, 48, 48, 512)
bn3a_branch1 (4, 48, 48, 512)
res3a_branch2a (4, 48, 48, 128)
bn3a_branch2a (4, 48, 48, 128)
res3a_branch2b (4, 48, 48, 128)
bn3a_branch2b (4, 48, 48, 128)
res3a_branch2c (4, 48, 48, 512)
bn3a_branch2c (4, 48, 48, 512)
res3a_relu (4, 48, 48, 512)
res3b1_branch2a (4, 48, 48, 128)
bn3b1_branch2a (4, 48, 48, 128)
res3b1_branch2b (4, 48, 48, 128)
bn3b1_branch2b (4, 48, 48, 128)
res3b1_branch2c (4, 48, 48, 512)
bn3b1_branch2c (4, 48, 48, 512)
res3b1_relu (4, 48, 48, 512)
res3b2_branch2a (4, 48, 48, 128)
bn3b2_branch2a (4, 48, 48, 128)
res3b2_branch2b (4, 48, 48, 128)
bn3b2_branch2b (4, 48, 48, 128)
res3b2_branch2c (4, 48, 48, 512)
bn3b2_branch2c (4, 48, 48, 512)
res3b2_relu (4, 48, 48, 512)
res3b3_branch2a (4, 48, 48, 128)
bn3b3_branch2a (4, 48, 48, 128)
res3b3_branch2b (4, 48, 48, 128)
bn3b3_branch2b (4, 48, 48, 128)
res3b3_branch2c (4, 48, 48, 512)
bn3b3_branch2c (4, 48, 48, 512)
res3b3_relu (4, 48, 48, 512)
res4a_branch1 (4, 48, 48, 1024)
bn4a_branch1 (4, 48, 48, 1024)
res4a_branch2a (4, 48, 48, 256)
bn4a_branch2a (4, 48, 48, 256)
res4a_branch2b (4, 48, 48, 256)
bn4a_branch2b (4, 48, 48, 256)
res4a_branch2c (4, 48, 48, 1024)
bn4a_branch2c (4, 48, 48, 1024)
res4a_relu (4, 48, 48, 1024)
res4b1_branch2a (4, 48, 48, 256)
bn4b1_branch2a (4, 48, 48, 256)
res4b1_branch2b (4, 48, 48, 256)
bn4b1_branch2b (4, 48, 48, 256)
res4b1_branch2c (4, 48, 48, 1024)
bn4b1_branch2c (4, 48, 48, 1024)
res4b1_relu (4, 48, 48, 1024)
res4b2_branch2a (4, 48, 48, 256)
bn4b2_branch2a (4, 48, 48, 256)
res4b2_branch2b (4, 48, 48, 256)
bn4b2_branch2b (4, 48, 48, 256)
res4b2_branch2c (4, 48, 48, 1024)
bn4b2_branch2c (4, 48, 48, 1024)
res4b2_relu (4, 48, 48, 1024)
res4b3_branch2a (4, 48, 48, 256)
bn4b3_branch2a (4, 48, 48, 256)
res4b3_branch2b (4, 48, 48, 256)
bn4b3_branch2b (4, 48, 48, 256)
res4b3_branch2c (4, 48, 48, 1024)
bn4b3_branch2c (4, 48, 48, 1024)
res4b3_relu (4, 48, 48, 1024)
res4b4_branch2a (4, 48, 48, 256)
bn4b4_branch2a (4, 48, 48, 256)
res4b4_branch2b (4, 48, 48, 256)
bn4b4_branch2b (4, 48, 48, 256)
res4b4_branch2c (4, 48, 48, 1024)
bn4b4_branch2c (4, 48, 48, 1024)
res4b4_relu (4, 48, 48, 1024)
res4b5_branch2a (4, 48, 48, 256)
bn4b5_branch2a (4, 48, 48, 256)
res4b5_branch2b (4, 48, 48, 256)
bn4b5_branch2b (4, 48, 48, 256)
res4b5_branch2c (4, 48, 48, 1024)
bn4b5_branch2c (4, 48, 48, 1024)
res4b5_relu (4, 48, 48, 1024)
res4b6_branch2a (4, 48, 48, 256)
bn4b6_branch2a (4, 48, 48, 256)
res4b6_branch2b (4, 48, 48, 256)
bn4b6_branch2b (4, 48, 48, 256)
res4b6_branch2c (4, 48, 48, 1024)
bn4b6_branch2c (4, 48, 48, 1024)
res4b6_relu (4, 48, 48, 1024)
res4b7_branch2a (4, 48, 48, 256)
bn4b7_branch2a (4, 48, 48, 256)
res4b7_branch2b (4, 48, 48, 256)
bn4b7_branch2b (4, 48, 48, 256)
res4b7_branch2c (4, 48, 48, 1024)
bn4b7_branch2c (4, 48, 48, 1024)
res4b7_relu (4, 48, 48, 1024)
res4b8_branch2a (4, 48, 48, 256)
bn4b8_branch2a (4, 48, 48, 256)
res4b8_branch2b (4, 48, 48, 256)
bn4b8_branch2b (4, 48, 48, 256)
res4b8_branch2c (4, 48, 48, 1024)
bn4b8_branch2c (4, 48, 48, 1024)
res4b8_relu (4, 48, 48, 1024)
res4b9_branch2a (4, 48, 48, 256)
bn4b9_branch2a (4, 48, 48, 256)
res4b9_branch2b (4, 48, 48, 256)
bn4b9_branch2b (4, 48, 48, 256)
res4b9_branch2c (4, 48, 48, 1024)
bn4b9_branch2c (4, 48, 48, 1024)
res4b9_relu (4, 48, 48, 1024)
res4b10_branch2a (4, 48, 48, 256)
bn4b10_branch2a (4, 48, 48, 256)
res4b10_branch2b (4, 48, 48, 256)
bn4b10_branch2b (4, 48, 48, 256)
res4b10_branch2c (4, 48, 48, 1024)
bn4b10_branch2c (4, 48, 48, 1024)
res4b10_relu (4, 48, 48, 1024)
res4b11_branch2a (4, 48, 48, 256)
bn4b11_branch2a (4, 48, 48, 256)
res4b11_branch2b (4, 48, 48, 256)
bn4b11_branch2b (4, 48, 48, 256)
res4b11_branch2c (4, 48, 48, 1024)
bn4b11_branch2c (4, 48, 48, 1024)
res4b11_relu (4, 48, 48, 1024)
res4b12_branch2a (4, 48, 48, 256)
bn4b12_branch2a (4, 48, 48, 256)
res4b12_branch2b (4, 48, 48, 256)
bn4b12_branch2b (4, 48, 48, 256)
res4b12_branch2c (4, 48, 48, 1024)
bn4b12_branch2c (4, 48, 48, 1024)
res4b12_relu (4, 48, 48, 1024)
res4b13_branch2a (4, 48, 48, 256)
bn4b13_branch2a (4, 48, 48, 256)
res4b13_branch2b (4, 48, 48, 256)
bn4b13_branch2b (4, 48, 48, 256)
res4b13_branch2c (4, 48, 48, 1024)
bn4b13_branch2c (4, 48, 48, 1024)
res4b13_relu (4, 48, 48, 1024)
res4b14_branch2a (4, 48, 48, 256)
bn4b14_branch2a (4, 48, 48, 256)
res4b14_branch2b (4, 48, 48, 256)
bn4b14_branch2b (4, 48, 48, 256)
res4b14_branch2c (4, 48, 48, 1024)
bn4b14_branch2c (4, 48, 48, 1024)
res4b14_relu (4, 48, 48, 1024)
res4b15_branch2a (4, 48, 48, 256)
bn4b15_branch2a (4, 48, 48, 256)
res4b15_branch2b (4, 48, 48, 256)
bn4b15_branch2b (4, 48, 48, 256)
res4b15_branch2c (4, 48, 48, 1024)
bn4b15_branch2c (4, 48, 48, 1024)
res4b15_relu (4, 48, 48, 1024)
res4b16_branch2a (4, 48, 48, 256)
bn4b16_branch2a (4, 48, 48, 256)
res4b16_branch2b (4, 48, 48, 256)
bn4b16_branch2b (4, 48, 48, 256)
res4b16_branch2c (4, 48, 48, 1024)
bn4b16_branch2c (4, 48, 48, 1024)
res4b16_relu (4, 48, 48, 1024)
res4b17_branch2a (4, 48, 48, 256)
bn4b17_branch2a (4, 48, 48, 256)
res4b17_branch2b (4, 48, 48, 256)
bn4b17_branch2b (4, 48, 48, 256)
res4b17_branch2c (4, 48, 48, 1024)
bn4b17_branch2c (4, 48, 48, 1024)
res4b17_relu (4, 48, 48, 1024)
res4b18_branch2a (4, 48, 48, 256)
bn4b18_branch2a (4, 48, 48, 256)
res4b18_branch2b (4, 48, 48, 256)
bn4b18_branch2b (4, 48, 48, 256)
res4b18_branch2c (4, 48, 48, 1024)
bn4b18_branch2c (4, 48, 48, 1024)
res4b18_relu (4, 48, 48, 1024)
res4b19_branch2a (4, 48, 48, 256)
bn4b19_branch2a (4, 48, 48, 256)
res4b19_branch2b (4, 48, 48, 256)
bn4b19_branch2b (4, 48, 48, 256)
res4b19_branch2c (4, 48, 48, 1024)
bn4b19_branch2c (4, 48, 48, 1024)
res4b19_relu (4, 48, 48, 1024)
res4b20_branch2a (4, 48, 48, 256)
bn4b20_branch2a (4, 48, 48, 256)
res4b20_branch2b (4, 48, 48, 256)
bn4b20_branch2b (4, 48, 48, 256)
res4b20_branch2c (4, 48, 48, 1024)
bn4b20_branch2c (4, 48, 48, 1024)
res4b20_relu (4, 48, 48, 1024)
res4b21_branch2a (4, 48, 48, 256)
bn4b21_branch2a (4, 48, 48, 256)
res4b21_branch2b (4, 48, 48, 256)
bn4b21_branch2b (4, 48, 48, 256)
res4b21_branch2c (4, 48, 48, 1024)
bn4b21_branch2c (4, 48, 48, 1024)
res4b21_relu (4, 48, 48, 1024)
res4b22_branch2a (4, 48, 48, 256)
bn4b22_branch2a (4, 48, 48, 256)
res4b22_branch2b (4, 48, 48, 256)
bn4b22_branch2b (4, 48, 48, 256)
res4b22_branch2c (4, 48, 48, 1024)
bn4b22_branch2c (4, 48, 48, 1024)
res4b22_relu (4, 48, 48, 1024)
res5a_branch1 (4, 48, 48, 2048)
bn5a_branch1 (4, 48, 48, 2048)
res5a_branch2a (4, 48, 48, 512)
bn5a_branch2a (4, 48, 48, 512)
res5a_branch2b (4, 48, 48, 512)
bn5a_branch2b (4, 48, 48, 512)
res5a_branch2c (4, 48, 48, 2048)
bn5a_branch2c (4, 48, 48, 2048)
res5a_relu (4, 48, 48, 2048)
res5b_branch2a (4, 48, 48, 512)
bn5b_branch2a (4, 48, 48, 512)
res5b_branch2b (4, 48, 48, 512)
bn5b_branch2b (4, 48, 48, 512)
res5b_branch2c (4, 48, 48, 2048)
bn5b_branch2c (4, 48, 48, 2048)
res5b_relu (4, 48, 48, 2048)
res5c_branch2a (4, 48, 48, 512)
bn5c_branch2a (4, 48, 48, 512)
res5c_branch2b (4, 48, 48, 512)
bn5c_branch2b (4, 48, 48, 512)
res5c_branch2c (4, 48, 48, 2048)
bn5c_branch2c (4, 48, 48, 2048)
res5c_relu (4, 48, 48, 2048)
fc1_voc12_c0 (4, 48, 48, 2)
fc1_voc12_c1 (4, 48, 48, 2)
fc1_voc12_c2 (4, 48, 48, 2)
fc1_voc12_c3 (4, 48, 48, 2)
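
Two things stand out in this listing. First, from res4a onwards the spatial resolution stays fixed at 48×48: following the DeepLab design, the later residual blocks replace striding with atrous (dilated) convolutions, so the output stride is 8 (384/48) rather than the 32 of a plain ResNet-101. Second, the four fc1_voc12_c* heads are the ASPP branches of DeepLab v2 (atrous convolutions with different dilation rates applied to the res5c output); each produces num_classes channels (2 in this run, since two classes were used) and they are summed into the final fc1_voc12 prediction.

The listing itself can be reproduced with a small loop over the model's layer dictionary. A minimal sketch, assuming net is a constructed DeepLabResNetModel whose layers attribute maps layer names to tensors (train.py reads fc1_voc12 from the same attribute); note that a plain dict may not preserve layer order:

# Minimal sketch: print the output shape of every layer in the model.
for name, tensor in net.layers.items():
    print(name, tensor.get_shape())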

2. train.py: below is a detailed walkthrough of the implementation:



from __future__ import print_function

import argparse
from datetime import datetime
import os
import sys
import time

import tensorflow as tf
import numpy as np

from deeplab_resnet import DeepLabResNetModel, ImageReader, decode_labels, inv_preprocess, prepare_label

IMG_MEAN = np.array((104.00698793,116.66876762,122.67891434), dtype=np.float32)

BATCH_SIZE = 10
DATA_DIRECTORY = '/home/VOCdevkit'
DATA_LIST_PATH = './dataset/train.txt'
IGNORE_LABEL = 255
INPUT_SIZE = '321,321'
LEARNING_RATE = 2.5e-4
MOMENTUM = 0.9
NUM_CLASSES = 21
NUM_STEPS = 20001
POWER = 0.9
RANDOM_SEED = 1234
RESTORE_FROM = './deeplab_resnet.ckpt'
SAVE_NUM_IMAGES = 2
SAVE_PRED_EVERY = 1000
SNAPSHOT_DIR = './snapshots/'
WEIGHT_DECAY = 0.0005


def get_arguments():
    """Parse all the arguments provided from the CLI.
    
    Returns:
      The parsed arguments (an argparse namespace).
    """
    parser = argparse.ArgumentParser(description="DeepLab-ResNet Network")
    parser.add_argument("--batch-size", type=int, default=BATCH_SIZE,
                        help="Number of images sent to the network in one step.")
    parser.add_argument("--data-dir", type=str, default=DATA_DIRECTORY,
                        help="Path to the directory containing the PASCAL VOC dataset.")
    parser.add_argument("--data-list", type=str, default=DATA_LIST_PATH,
                        help="Path to the file listing the images in the dataset.")
    parser.add_argument("--ignore-label", type=int, default=IGNORE_LABEL,
                        help="The index of the label to ignore during the training.")
    parser.add_argument("--input-size", type=str, default=INPUT_SIZE,
                        help="Comma-separated string with height and width of images.")
    parser.add_argument("--is-training", action="store_true",
                        help="Whether to updates the running means and variances during the training.")
    parser.add_argument("--learning-rate", type=float, default=LEARNING_RATE,
                        help="Base learning rate for training with polynomial decay.")
    parser.add_argument("--momentum", type=float, default=MOMENTUM,
                        help="Momentum component of the optimiser.")
    parser.add_argument("--not-restore-last", action="store_true",
                        help="Whether to not restore last (FC) layers.")
    parser.add_argument("--num-classes", type=int, default=NUM_CLASSES,
                        help="Number of classes to predict (including background).")
    parser.add_argument("--num-steps", type=int, default=NUM_STEPS,
                        help="Number of training steps.")
    parser.add_argument("--power", type=float, default=POWER,
                        help="Decay parameter to compute the learning rate.")
    parser.add_argument("--random-mirror", action="store_true",
                        help="Whether to randomly mirror the inputs during the training.")
    parser.add_argument("--random-scale", action="store_true",
                        help="Whether to randomly scale the inputs during the training.")
    parser.add_argument("--random-seed", type=int, default=RANDOM_SEED,
                        help="Random seed to have reproducible results.")
    parser.add_argument("--restore-from", type=str, default=RESTORE_FROM,
                        help="Where restore model parameters from.")
    parser.add_argument("--save-num-images", type=int, default=SAVE_NUM_IMAGES,
                        help="How many images to save.")
    parser.add_argument("--save-pred-every", type=int, default=SAVE_PRED_EVERY,
                        help="Save summaries and checkpoint every often.")
    parser.add_argument("--snapshot-dir", type=str, default=SNAPSHOT_DIR,
                        help="Where to save snapshots of the model.")
    parser.add_argument("--weight-decay", type=float, default=WEIGHT_DECAY,
                        help="Regularisation parameter for L2-loss.")
    return parser.parse_args()
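
# As an aside, training could be launched from a shell with flags such as
# (hypothetical example values; omitted flags fall back to the defaults above):
#   python train.py --data-dir /home/VOCdevkit --batch-size 4 \
#                   --input-size 384,384 --random-mirror --random-scale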
    
# Function that saves a checkpoint.
def save(saver, sess, logdir, step):
    '''Save weights.

    Args:
      saver: TensorFlow Saver object.
      sess: TensorFlow session.
      logdir: path to the snapshots directory.
      step: current training step.
    '''
    # Name and path under which the checkpoint is stored.
    model_name = 'model.ckpt'
    checkpoint_path = os.path.join(logdir, model_name)

    if not os.path.exists(logdir):
        os.makedirs(logdir)
    saver.save(sess, checkpoint_path, global_step=step)
    print('The checkpoint has been created.')

# Reload function: restores network parameters from a .ckpt file so training can continue from them.
def load(saver, sess, ckpt_path):
    '''Load trained weights.
    
    Args:
      saver: TensorFlow Saver object.
      sess: TensorFlow session.
      ckpt_path: path to checkpoint file with parameters.
    ''' 
    saver.restore(sess, ckpt_path)
    print("Restored model parameters from {}".format(ckpt_path))

def main():
    """Create the model and start the training."""
    # When invoked from the command line, parse the arguments that were passed in.
    # If the function is not called via the CLI, the same effect can be achieved
    # by editing the default values above.
    args = get_arguments()
    
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)
    
    tf.set_random_seed(args.random_seed)
    
    # Data is fed to the network through a TensorFlow queue, initialised below.
    # First, create the thread coordinator.
    # Create queue coordinator.
    coord = tf.train.Coordinator()
    
    # Load reader.
    with tf.name_scope("create_inputs"):
        reader = ImageReader(
            args.data_dir,
            args.data_list,
            input_size,
            args.random_scale,
            args.random_mirror,
            args.ignore_label,
            IMG_MEAN,
            coord)
        # The queue outputs image_batch and label_batch.
        image_batch, label_batch = reader.dequeue(args.batch_size)
    
    # Create network.
    net = DeepLabResNetModel({'data': image_batch}, is_training=args.is_training, num_classes=args.num_classes)
    # For a small batch size, it is better to keep 
    # the statistics of the BN layers (running means and variances)
    # frozen, and to not update the values provided by the pre-trained model. 
    # If is_training=True, the statistics will be updated during the training.
    # Note that is_training=False still updates BN parameters gamma (scale) and beta (offset)
    # if they are presented in var_list of the optimiser definition.

    # Predictions.
    raw_output = net.layers['fc1_voc12']
    # Decide which of the network's parameters need to be trained and which do not.
    # Which variables to load. Running means and variances are not trainable,
    # thus all_variables() should be restored.
    restore_var = [v for v in tf.global_variables() if 'fc' not in v.name or not args.not_restore_last]
    all_trainable = [v for v in tf.trainable_variables() if 'beta' not in v.name and 'gamma' not in v.name]
    fc_trainable = [v for v in all_trainable if 'fc' in v.name]
    
    # Among the trainable parameters: convolutional weights use learning rate lr,
    # FC-layer weights use 10*lr, and FC-layer biases use 20*lr.
    conv_trainable = [v for v in all_trainable if 'fc' not in v.name] # lr * 1.0
    fc_w_trainable = [v for v in fc_trainable if 'weights' in v.name] # lr * 10.0
    fc_b_trainable = [v for v in fc_trainable if 'biases' in v.name] # lr * 20.0
    assert(len(all_trainable) == len(fc_trainable) + len(conv_trainable))
    assert(len(fc_trainable) == len(fc_w_trainable) + len(fc_b_trainable))
    
    
    # Predictions: ignore all predictions whose label is greater than or equal to num_classes.
    # Reshape the network output to [-1, args.num_classes].
    raw_prediction = tf.reshape(raw_output, [-1, args.num_classes])
    # Since sparse_softmax_cross_entropy_with_logits is used, the labels are
    # reshaped to [batch_size, h, w], i.e. the channel dimension is dropped.
    # With softmax_cross_entropy_with_logits one would set one_hot=True instead,
    # so that the labels become [batch_size, h, w, num_classes].
    label_proc = prepare_label(label_batch, tf.stack(raw_output.get_shape()[1:3]), num_classes=args.num_classes, one_hot=False) # [batch_size, h, w]
    # Flatten to the same shape as raw_prediction.
    raw_gt = tf.reshape(label_proc, [-1,])
    # Drop label values greater than or equal to num_classes (e.g. the ignore label 255).
    indices = tf.squeeze(tf.where(tf.less_equal(raw_gt, args.num_classes - 1)), 1)
    gt = tf.cast(tf.gather(raw_gt, indices), tf.int32)
    prediction = tf.gather(raw_prediction, indices)
                                                  
                                                  
    # Pixel-wise softmax loss.
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=prediction, labels=gt)
    l2_losses = [args.weight_decay * tf.nn.l2_loss(v) for v in tf.trainable_variables() if 'weights' in v.name]
    reduced_loss = tf.reduce_mean(loss) + tf.add_n(l2_losses)
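    # A hedged sketch of the dense-label alternative mentioned above (not used
    # here; shown only for contrast with the sparse variant):
    #   label_proc = prepare_label(label_batch, tf.stack(raw_output.get_shape()[1:3]),
    #                              num_classes=args.num_classes, one_hot=True)
    #   loss = tf.nn.softmax_cross_entropy_with_logits(
    #       logits=raw_prediction, labels=tf.reshape(label_proc, [-1, args.num_classes]))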
    
    # Processed predictions: for visualisation.
    raw_output_up = tf.image.resize_bilinear(raw_output, tf.shape(image_batch)[1:3,])
    raw_output_up = tf.argmax(raw_output_up, dimension=3)
    pred = tf.expand_dims(raw_output_up, dim=3)
    
    # Image summary.
    images_summary = tf.py_func(inv_preprocess, [image_batch, args.save_num_images, IMG_MEAN], tf.uint8)
    labels_summary = tf.py_func(decode_labels, [label_batch, args.save_num_images, args.num_classes], tf.uint8)
    preds_summary = tf.py_func(decode_labels, [pred, args.save_num_images, args.num_classes], tf.uint8)
    
    total_summary = tf.summary.image('images', 
                                     tf.concat(axis=2, values=[images_summary, labels_summary, preds_summary]), 
                                     max_outputs=args.save_num_images) # Concatenate row-wise.
    summary_writer = tf.summary.FileWriter(args.snapshot_dir,
                                           graph=tf.get_default_graph())
   
    # Define loss and optimisation parameters.
    base_lr = tf.constant(args.learning_rate)
    step_ph = tf.placeholder(dtype=tf.float32, shape=())
    # Set up learning-rate decay.
    # This part is instructive: rather than handing one fixed value to
    # tf.train.MomentumOptimizer, the author gives different parameter groups
    # different learning rates and applies a polynomial decay schedule to the
    # base rate, wiring the schedule into the graph via a placeholder.
    learning_rate = tf.scalar_mul(base_lr, tf.pow((1 - step_ph / args.num_steps), args.power))
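    # For example, with the defaults base_lr = 2.5e-4, power = 0.9 and
    # num_steps = 20001: lr = 2.5e-4 at step 0, decays to roughly
    # 2.5e-4 * 0.5^0.9 ≈ 1.34e-4 halfway through, and reaches 0 at the last step.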
    
    opt_conv = tf.train.MomentumOptimizer(learning_rate, args.momentum)
    opt_fc_w = tf.train.MomentumOptimizer(learning_rate * 10.0, args.momentum)
    opt_fc_b = tf.train.MomentumOptimizer(learning_rate * 20.0, args.momentum)

    grads = tf.gradients(reduced_loss, conv_trainable + fc_w_trainable + fc_b_trainable)
    grads_conv = grads[:len(conv_trainable)]
    grads_fc_w = grads[len(conv_trainable) : (len(conv_trainable) + len(fc_w_trainable))]
    grads_fc_b = grads[(len(conv_trainable) + len(fc_w_trainable)):]

    train_op_conv = opt_conv.apply_gradients(zip(grads_conv, conv_trainable))
    train_op_fc_w = opt_fc_w.apply_gradients(zip(grads_fc_w, fc_w_trainable))
    train_op_fc_b = opt_fc_b.apply_gradients(zip(grads_fc_b, fc_b_trainable))

    train_op = tf.group(train_op_conv, train_op_fc_w, train_op_fc_b)
    
    
    # Set up tf session and initialize variables. 
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    init = tf.global_variables_initializer()
    
    sess.run(init)
    
    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)
    
    # Load variables if the checkpoint is provided.
    if args.restore_from is not None:
        loader = tf.train.Saver(var_list=restore_var)
        load(loader, sess, args.restore_from)
    
    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    # Iterate over training steps.
    for step in range(args.num_steps):
        start_time = time.time()
        feed_dict = { step_ph : step }
        
        if step % args.save_pred_every == 0:
            loss_value, images, labels, preds, summary, _ = sess.run([reduced_loss, image_batch, label_batch, pred, total_summary, train_op], feed_dict=feed_dict)
            summary_writer.add_summary(summary, step)
            save(saver, sess, args.snapshot_dir, step)
        else:
            loss_value, _ = sess.run([reduced_loss, train_op], feed_dict=feed_dict)
        duration = time.time() - start_time
        print('step {:d} \t loss = {:.3f}, ({:.3f} sec/step)'.format(step, loss_value, duration))
    coord.request_stop()
    coord.join(threads)
    
# Program entry point.
if __name__ == '__main__':
    main()

3. image_reader.py: its main job is to initialise a TensorFlow queue that feeds data to the network.

import os

import numpy as np
import tensorflow as tf

def image_scaling(img, label):
    """
    Randomly scales the images between 0.5 and 1.5 times the original size.
    Args:
      img: Training image to scale.
      label: Segmentation mask to scale.
    """
    
    scale = tf.random_uniform([1], minval=0.5, maxval=1.5, dtype=tf.float32, seed=None)
    h_new = tf.to_int32(tf.multiply(tf.to_float(tf.shape(img)[0]), scale))
    w_new = tf.to_int32(tf.multiply(tf.to_float(tf.shape(img)[1]), scale))
    new_shape = tf.squeeze(tf.stack([h_new, w_new]), squeeze_dims=[1])
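    # h_new and w_new each have shape [1]; stacking them gives shape [2, 1], and
    # the squeeze collapses this to the [2] shape expected by resize_images.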
    img = tf.image.resize_images(img, new_shape)
    label = tf.image.resize_nearest_neighbor(tf.expand_dims(label, 0), new_shape)
    label = tf.squeeze(label, squeeze_dims=[0])
   
    return img, label

def image_mirroring(img, label):
    """
    Randomly mirrors the images.
    Args:
      img: Training image to mirror.
      label: Segmentation mask to mirror.
    """
    
    distort_left_right_random = tf.random_uniform([1], 0, 1.0, dtype=tf.float32)[0]
    mirror = tf.less(tf.stack([1.0, distort_left_right_random, 1.0]), 0.5)
    mirror = tf.boolean_mask([0, 1, 2], mirror)
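    # `mirror` is [1] (flip along the width axis) with probability 0.5 and []
    # otherwise: of [1.0, r, 1.0], only the middle entry can fall below 0.5.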
    img = tf.reverse(img, mirror)
    label = tf.reverse(label, mirror)
    return img, label

def random_crop_and_pad_image_and_labels(image, label, crop_h, crop_w, ignore_label=255):
    """
    Randomly crop and pads the input images.
    Args:
      image: Training image to crop/ pad.
      label: Segmentation mask to crop/ pad.
      crop_h: Height of cropped segment.
      crop_w: Width of cropped segment.
      ignore_label: Label to ignore during the training.
    """

    label = tf.cast(label, dtype=tf.float32)
    label = label - ignore_label # Needs to be subtracted and later added due to 0 padding.
    combined = tf.concat(axis=2, values=[image, label]) 
    image_shape = tf.shape(image)
    combined_pad = tf.image.pad_to_bounding_box(combined, 0, 0, tf.maximum(crop_h, image_shape[0]), tf.maximum(crop_w, image_shape[1]))
    
    last_image_dim = tf.shape(image)[-1]
    last_label_dim = tf.shape(label)[-1]
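    # The crop depth of 4 below = 3 image channels + 1 label channel.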
    combined_crop = tf.random_crop(combined_pad, [crop_h,crop_w,4])
    img_crop = combined_crop[:, :, :last_image_dim]
    label_crop = combined_crop[:, :, last_image_dim:]
    label_crop = label_crop + ignore_label
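    # Padded pixels were 0 in `combined_pad`, so adding `ignore_label` back turns
    # them into exactly `ignore_label` (and the loss skips them), while genuine
    # labels round-trip unchanged.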
    label_crop = tf.cast(label_crop, dtype=tf.uint8)
    
    # Set static shape so that tensorflow knows shape at compile time. 
    img_crop.set_shape((crop_h, crop_w, 3))
    label_crop.set_shape((crop_h,crop_w, 1))
    return img_crop, label_crop  

def read_labeled_image_list(data_dir, data_list):
    """Reads txt file containing paths to images and ground truth masks.
    
    Args:
      data_dir: path to the directory with images and masks.
      data_list: path to the file with lines of the form '/path/to/image /path/to/mask'.
       
    Returns:
      Two lists with all file names for images and masks, respectively.
    """
    images = []
    masks = []
    # Use a context manager so the list file is closed deterministically.
    with open(data_list, 'r') as f:
        for line in f:
            try:
                image, mask = line.strip("\n").split(' ')
            except ValueError: # Adhoc for test.
                image = mask = line.strip("\n")
            images.append(data_dir + image)
            masks.append(data_dir + mask)
    return images, masks

def read_images_from_disk(input_queue, input_size, random_scale, random_mirror, ignore_label, img_mean): # optional pre-processing arguments
    """Read one image and its corresponding mask with optional pre-processing.
    
    Args:
      input_queue: tf queue with paths to the image and its mask.
      input_size: a tuple with (height, width) values.
                  If not given, return images of original size.
      random_scale: whether to randomly scale the images prior
                    to random crop.
      random_mirror: whether to randomly mirror the images prior
                    to random crop.
      ignore_label: index of label to ignore during the training.
      img_mean: vector of mean colour values.
      
    Returns:
      Two tensors: the decoded image and its mask.
    """

    img_contents = tf.read_file(input_queue[0])
    label_contents = tf.read_file(input_queue[1])
    
    img = tf.image.decode_jpeg(img_contents, channels=3)
    img_r, img_g, img_b = tf.split(axis=2, num_or_size_splits=3, value=img)
    img = tf.cast(tf.concat(axis=2, values=[img_b, img_g, img_r]), dtype=tf.float32)
    # Extract mean.
    img -= img_mean

    label = tf.image.decode_png(label_contents, channels=1)

    if input_size is not None:
        h, w = input_size

        # Randomly scale the images and labels.
        if random_scale:
            img, label = image_scaling(img, label)

        # Randomly mirror the images and labels.
        if random_mirror:
            img, label = image_mirroring(img, label)

        # Randomly crops the images and labels.
        img, label = random_crop_and_pad_image_and_labels(img, label, h, w, ignore_label)

    return img, label

class ImageReader(object):
    '''Generic ImageReader which reads images and corresponding segmentation
       masks from the disk, and enqueues them into a TensorFlow queue.
    '''

    def __init__(self, data_dir, data_list, input_size, 
                 random_scale, random_mirror, ignore_label, img_mean, coord):
        '''Initialise an ImageReader.
        
        Args:
          data_dir: path to the directory with images and masks.
          data_list: path to the file with lines of the form '/path/to/image /path/to/mask'.
          input_size: a tuple with (height, width) values, to which all the images will be resized.
          random_scale: whether to randomly scale the images prior to random crop.
          random_mirror: whether to randomly mirror the images prior to random crop.
          ignore_label: index of label to ignore during the training.
          img_mean: vector of mean colour values.
          coord: TensorFlow queue coordinator.
        '''
        self.data_dir = data_dir
        self.data_list = data_list
        self.input_size = input_size
        self.coord = coord
        
        # self.image_list and self.label_list contain the paths of all images and all labels, respectively.
        self.image_list, self.label_list = read_labeled_image_list(self.data_dir, self.data_list)
        # Convert the lists to tensors so they can flow through the graph.
        self.images = tf.convert_to_tensor(self.image_list, dtype=tf.string)
        self.labels = tf.convert_to_tensor(self.label_list, dtype=tf.string)
        # Note that the inputs must be passed as a list.
        # Create a queue that yields one randomly chosen image/label path pair at a time.
        self.queue = tf.train.slice_input_producer([self.images, self.labels],
                                                   shuffle=input_size is not None) # not shuffling if it is val
        # Read the image and its label from the dequeued paths.
        self.image, self.label = read_images_from_disk(self.queue, self.input_size, random_scale, random_mirror, ignore_label, img_mean) 

    def dequeue(self, num_elements):
        '''Pack images and labels into a batch.
        
        Args:
          num_elements: the batch size.
          
        Returns:
          Two tensors of size (batch_size, h, w, {3, 1}) for images and masks.'''
        image_batch, label_batch = tf.train.batch([self.image, self.label],
                                                  num_elements)
        return image_batch, label_batch
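
To make the reader's moving parts concrete, here is a minimal stand-alone sketch of driving ImageReader outside of train.py (the paths are hypothetical and the mean values mirror train.py's defaults):

import numpy as np
import tensorflow as tf

# BGR mean colour values, as in train.py.
IMG_MEAN = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32)

coord = tf.train.Coordinator()
reader = ImageReader('/home/VOCdevkit', './dataset/train.txt', (321, 321),
                     random_scale=True, random_mirror=True,
                     ignore_label=255, img_mean=IMG_MEAN, coord=coord)
image_batch, label_batch = reader.dequeue(4)

with tf.Session() as sess:
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)
    imgs, labels = sess.run([image_batch, label_batch])
    print(imgs.shape, labels.shape)  # expected: (4, 321, 321, 3) (4, 321, 321, 1)
    coord.request_stop()
    coord.join(threads)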