
Training DeepLabV3+ on Your Own Dataset (Part 3)

Model Training and Testing

1. Starting from the DeepLabv3+ code, two files mainly need to be modified:

datasets/data_generator.py

utils/train_utils.py

(1) Add a dataset descriptor

In datasets/data_generator.py, add a descriptor for your own dataset:
_CAMVID_INFORMATION = DatasetDescriptor(
    splits_to_sizes={
        'train': 1035,  # number of training examples
        'val': 31,      # number of validation examples
    },
    num_classes=3,      # background, blackboard, screen
    ignore_label=255,
)
The dataset has 3 classes in total, background included. Its annotations do not use ignore_label, so ignore_label is not counted among the classes.
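Before converting to TFRecords, it is worth checking that each label mask only contains class ids 0 to num_classes - 1, plus ignore_label where it is used. A minimal NumPy sketch; the helper name find_invalid_label_values is hypothetical, and loading the PNG mask (for example with np.asarray(PIL.Image.open(path))) is left to the caller:

```python
import numpy as np


def find_invalid_label_values(label, num_classes, ignore_label=255):
    """Return sorted label values that are neither a class id nor ignore_label.

    Class ids are assumed to be 0 .. num_classes - 1, background included.
    """
    values = np.unique(np.asarray(label))
    valid = (values < num_classes) | (values == ignore_label)
    return values[~valid].tolist()


# Synthetic 2x3 mask with the 3 classes plus one ignored pixel.
mask = np.array([[0, 1, 2],
                 [2, 255, 0]], dtype=np.uint8)
print(find_invalid_label_values(mask, num_classes=3))  # []

# A mask containing values 3 and 7 would break num_classes=3.
bad = np.array([[0, 3], [7, 1]], dtype=np.uint8)
print(find_invalid_label_values(bad, num_classes=3))   # [3, 7]
```

An empty list means the mask agrees with num_classes=3 and ignore_label=255.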

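(2) Register the dataset

Add the new descriptor to _DATASETS_INFORMATION in the same file; the key ('camvid' here) is the name later passed as --dataset: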

_DATASETS_INFORMATION = {
    'cityscapes': _CITYSCAPES_INFORMATION,
    'pascal_voc_seg': _PASCAL_VOC_SEG_INFORMATION,
    'ade20k': _ADE20K_INFORMATION,
    'camvid': _CAMVID_INFORMATION,  # the custom dataset, selected with --dataset="camvid"
    # 'mydata':_MYDATA_INFORMATION,
    }
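To see what the registration step buys you, here is a self-contained stand-in for the registry. It does not import DeepLab: DatasetDescriptor is re-declared as a namedtuple with the three fields used above, and lookup_dataset is a hypothetical helper. As far as I know, the real data_generator.py raises its own ValueError when the --dataset name is missing from _DATASETS_INFORMATION, which is the failure this mimics:

```python
import collections

# Stand-in with the same three fields as the descriptor above.
DatasetDescriptor = collections.namedtuple(
    'DatasetDescriptor', ['splits_to_sizes', 'num_classes', 'ignore_label'])

_CAMVID_INFORMATION = DatasetDescriptor(
    splits_to_sizes={'train': 1035, 'val': 31},
    num_classes=3,
    ignore_label=255,
)

_DATASETS_INFORMATION = {
    'camvid': _CAMVID_INFORMATION,
}


def lookup_dataset(name, split):
    """Resolve a dataset name and split the way a registry lookup would."""
    if name not in _DATASETS_INFORMATION:
        raise ValueError('Dataset %r is not registered.' % name)
    info = _DATASETS_INFORMATION[name]
    if split not in info.splits_to_sizes:
        raise ValueError('Split %r is not defined for %r.' % (split, name))
    return info.splits_to_sizes[split], info.num_classes


print(lookup_dataset('camvid', 'train'))  # (1035, 3)
```

Calling lookup_dataset('mydata', 'train') raises ValueError, which is why the commented-out 'mydata' entry must be enabled before it can be used.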

(3) Modify train_utils.py

In utils/train_utils.py, change the exclude_list setting around line 210 so that the logits layer is not restored when loading pretrained weights:

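# Also exclude 'logits': the pretrained classifier layer was built for a
# different number of classes, so it must not be restored.
exclude_list = ['global_step', 'logits']
if not initialize_last_layer:
    exclude_list.extend(last_layers)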

To fine-tune DeepLab on another dataset, adjust the input flags in deeplab/train.py. Some options:

Use all pretrained weights: set initialize_last_layer=True.
Use only the network backbone: set initialize_last_layer=False and last_layers_contain_logits_only=False.
Use all pretrained weights except the logits: with your own dataset the number of classes differs (and we already excluded the logits above), so set initialize_last_layer=False and last_layers_contain_logits_only=True.

The settings used here are:

initialize_last_layer=False  # line 157
last_layers_contain_logits_only=True  # line 160
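The interplay between the two flags and the modified exclude_list can be simulated in plain Python. This is a simplified stand-in, not DeepLab code: the head scope names below are the ones I recall from get_extra_layer_scopes in DeepLab's model.py, the variable names are made up for illustration, and filtering by name prefix only approximates how variables are selected for restoring. Check both against your checkout:

```python
# Assumed head scopes (from memory of model.py); verify in your version.
_ALL_LAST_LAYERS = ['logits', 'image_pooling', 'aspp',
                    'concat_projection', 'decoder', 'meta_architecture']


def last_layer_scopes(last_layers_contain_logits_only):
    """Effect of the flag: only 'logits', or the whole decoder/ASPP head."""
    if last_layers_contain_logits_only:
        return ['logits']
    return list(_ALL_LAST_LAYERS)


def build_exclude_list(initialize_last_layer, last_layers_contain_logits_only):
    """Same logic as the modified train_utils.py snippet."""
    exclude_list = ['global_step', 'logits']
    if not initialize_last_layer:
        exclude_list.extend(last_layer_scopes(last_layers_contain_logits_only))
    return exclude_list


def variables_to_restore(variable_names, exclude_list):
    """Keep variables whose name does not start with an excluded scope."""
    return [v for v in variable_names
            if not any(v.startswith(scope) for scope in exclude_list)]


# Hypothetical variable names, one per part of the network.
names = ['global_step', 'xception_65/entry_flow/conv1_1/weights',
         'aspp0/weights', 'decoder/decoder_conv0_0/weights',
         'logits/semantic/weights']

# Setting used in this post: everything except the logits is restored.
print(variables_to_restore(names, build_exclude_list(False, True)))
# Backbone only: ASPP and decoder are re-initialized as well.
print(variables_to_restore(names, build_exclude_list(False, False)))
```

Note that with the 'logits' entry added to exclude_list, even initialize_last_layer=True no longer restores the logits layer.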

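2. Network Training

(1) Download a pretrained model

Download link: https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/model_zoo.md

Download the archive into the deeplab directory and extract it:

tar -zxvf deeplabv3_cityscapes_train_2018_02_06.tar.gz

Note the resulting extracted directory, which here is: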
/lwh/models/research/deeplab/deeplabv3_cityscapes_train
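--tf_initial_checkpoint expects a checkpoint prefix (.../model.ckpt), not a file with that exact name. Assuming the extracted archive holds a TensorFlow V2-format checkpoint (a model.ckpt.index file plus one or more model.ckpt.data-* shards), a hypothetical helper can confirm the prefix before launching a long training run:

```python
import glob
import os


def checkpoint_prefix_exists(prefix):
    """True if a V2-format checkpoint with this prefix appears to be on disk."""
    has_index = os.path.isfile(prefix + '.index')
    # glob.escape guards against special characters in the directory name.
    has_data = bool(glob.glob(glob.escape(prefix) + '.data-*'))
    return has_index and has_data


print(checkpoint_prefix_exists(
    '/lwh/models/research/deeplab/deeplabv3_cityscapes_train/model.ckpt'))
```

If this prints False, the path passed to --tf_initial_checkpoint is probably wrong (a common mistake is pointing it at the .index file itself).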

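(2) Correcting class imbalance

The dataset in this blackboard segmentation project is a 3-class problem in which the background occupies a very large share of the pixels, so the class weights are set to 1, 3, 3. Note that the weights affect the final segmentation performance, and suitable values differ from dataset to dataset. Change the weights at line 145 of common.py as follows: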
flags.DEFINE_multi_float(
    'label_weights', [1.0, 3.0, 3.0],  # background, blackboard, screen
    'A list of label weights, each element represents the weight for the label '
    'of its index, for example, label_weights = [0.1, 0.5] means the weight '
    'for label 0 is 0.1 and the weight for label 1 is 0.5. If set as None, all '
    'the labels have the same weight 1.0.')
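The 1, 3, 3 weights above were picked by hand. For a data-driven starting point, one common heuristic (not the method used in this post) is median-frequency balancing. The sketch below is a simplified variant that uses global pixel frequencies over a list of label maps and skips ignore_label pixels; median_frequency_weights is a hypothetical helper, and its output can give much larger ratios than the hand-picked values, so treat it as something to tune rather than a final answer:

```python
import numpy as np


def median_frequency_weights(label_maps, num_classes):
    """Simplified median-frequency balancing over 2-D label maps.

    freq[c] = pixels of class c / all counted pixels
    weight[c] = median(freq of present classes) / freq[c]
    Pixels with value >= num_classes (e.g. ignore_label 255) are skipped.
    """
    counts = np.zeros(num_classes, dtype=np.float64)
    for label in label_maps:
        flat = np.asarray(label).ravel()
        flat = flat[flat < num_classes]
        counts += np.bincount(flat, minlength=num_classes)[:num_classes]
    freq = counts / counts.sum()
    present = freq > 0
    weights = np.zeros(num_classes)
    weights[present] = np.median(freq[present]) / freq[present]
    return weights.round(3).tolist()


# Synthetic mask: 80 background, 10 blackboard, 10 screen, 10 ignored pixels.
label = np.array([0] * 80 + [1] * 10 + [2] * 10 + [255] * 10,
                 dtype=np.uint8).reshape(11, 10)
print(median_frequency_weights([label], num_classes=3))  # [0.125, 1.0, 1.0]
```

With background at 80% of the counted pixels, the result is a 1:8:8 ratio, noticeably stronger than 1:3:3.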

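(3) Training

Pay attention to the following flags:

train_logdir: where the files produced during training are stored
dataset_dir: the directory containing the dataset's TFRecord files
dataset: the dataset name registered in data_generator.py

Run the following training command from ~/models/research/deeplab: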
python train.py \
    --training_number_of_steps=30000 --train_split="train" --model_variant="xception_65" \
    --atrous_rates=6 --atrous_rates=12 --atrous_rates=18 --output_stride=16 --decoder_output_stride=4 \
    --train_crop_size=801,801 --train_batch_size=2 --dataset="camvid" \
    --tf_initial_checkpoint='/lwh/models/research/deeplab/deeplabv3_cityscapes_train/model.ckpt' \
    --train_logdir='/lwh/models/research/deeplab/exp/blackboard_train/train' \
    --dataset_dir='/lwh/models/research/deeplab/datasets/blackboard/tfrecord'
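As a quick sanity check on --training_number_of_steps, the arithmetic below converts steps into passes over the training split. It assumes each step consumes exactly train_batch_size examples (a single clone, no gradient accumulation); both helper names are hypothetical:

```python
import math


def epochs_covered(num_steps, batch_size, num_train_examples):
    """How many passes over the training split a run makes."""
    return num_steps * batch_size / num_train_examples


def steps_for_epochs(num_epochs, batch_size, num_train_examples):
    """Steps needed (rounded up) to cover num_epochs passes."""
    return math.ceil(num_epochs * num_train_examples / batch_size)


# Values from the command above: 30000 steps, batch size 2, 1035 training images.
print(round(epochs_covered(30000, 2, 1035), 1))  # 58.0
print(steps_for_epochs(50, 2, 1035))             # 25875
```

So 30000 steps at batch size 2 is roughly 58 epochs over the 1035 training images.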

Rule for choosing train_crop_size:

It should be output_stride * k + 1, where k is an integer. For example: 321x321, 513x513, 801x801.
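The rule above is easy to check or apply in code. Both helpers below are hypothetical; next_valid_crop rounds a desired size up to the nearest value of the form output_stride * k + 1:

```python
def is_valid_crop(size, output_stride=16):
    """True if size == output_stride * k + 1 for some integer k >= 0."""
    return size >= 1 and (size - 1) % output_stride == 0


def next_valid_crop(size, output_stride=16):
    """Smallest value of the form output_stride * k + 1 that is >= size."""
    k = -(-(size - 1) // output_stride)  # ceiling division
    return output_stride * max(k, 0) + 1


print([s for s in (321, 513, 800, 801) if is_valid_crop(s)])  # [321, 513, 801]
print(next_valid_crop(800), next_valid_crop(1080))             # 801 1089
```

With output_stride=16, 321 = 16 * 20 + 1, 513 = 16 * 32 + 1 and 801 = 16 * 50 + 1.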

(4) Exporting the model

  

python export_model.py \
    --logtostderr \
    --checkpoint_path="/lwh/models/research/deeplab/exp/blackboard_train/train/model.ckpt-30000" \
    --export_path="/lwh/models/research/deeplab/exp/blackboard_train/train/frozen_inference_graph.pb"  \
    --model_variant="xception_65"  \
    --atrous_rates=6  \
    --atrous_rates=12  \
    --atrous_rates=18   \
    --output_stride=16  \
    --decoder_output_stride=4  \
    --num_classes=3 \
    --crop_size=1080 \
    --crop_size=1920 \
    --inference_scales=1.0

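A few notes:

--checkpoint_path: the path where your own trained checkpoint is saved

--export_path: where the exported model is saved

--num_classes=3: the number of classes in your data, background included

--crop_size=1080: the first value is the input height h expected by the model

--crop_size=1920: the second value is the input width w expected by the model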
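PIL reports image sizes as (width, height), while the two --crop_size flags are given as height then width, so the order is easy to swap. A small check, assuming (from my reading of export_model.py's preprocessing, so please verify) that the exported graph pads inputs up to crop_size and cannot handle inputs larger than it; fits_export_crop is a hypothetical helper:

```python
def fits_export_crop(image_size, crop_size):
    """Check an input image against the export crop.

    image_size: PIL-style (width, height), as returned by Image.size.
    crop_size:  (height, width), the order of the two --crop_size flags.
    """
    width, height = image_size
    crop_h, crop_w = crop_size
    return height <= crop_h and width <= crop_w


EXPORT_CROP = (1080, 1920)  # --crop_size=1080 --crop_size=1920
print(fits_export_crop((1920, 1080), EXPORT_CROP))  # True  (landscape 1920x1080)
print(fits_export_crop((1080, 1920), EXPORT_CROP))  # False (portrait image)
```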

3. Model Testing

Here is the code:

  

# -*- coding: utf-8 -*-

# DeepLab demo: run an exported frozen graph on a folder of test images

import os

from matplotlib import gridspec
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

import tensorflow as tf  # TF1-style API (tf.GraphDef, tf.Session)


class DeepLabModel(object):
    """Loads a frozen DeepLab model and runs inference."""
    INPUT_TENSOR_NAME = 'ImageTensor:0'
    OUTPUT_TENSOR_NAME = 'SemanticPredictions:0'
    INPUT_SIZE = 1920
    FROZEN_GRAPH_NAME = 'frozen_inference_graph'

    def __init__(self, frozen_graph_path):
        """Creates and loads a pretrained DeepLab model from a frozen .pb file."""
        self.graph = tf.Graph()

        # Parse the serialized GraphDef written by export_model.py.
        with open(frozen_graph_path, 'rb') as f:
            graph_def = tf.GraphDef.FromString(f.read())

        with self.graph.as_default():
            tf.import_graph_def(graph_def, name='')
        self.sess = tf.Session(graph=self.graph)

    def run(self, image):
        """Runs inference on a single image.

        Args:
            image: A PIL.Image object, raw input image.
        Returns:
            resized_image: RGB image resized from the original input image.
            seg_map: Segmentation map of `resized_image`.
        """
        width, height = image.size
        resize_ratio = 1.0 * self.INPUT_SIZE / max(width, height)
        target_size = (int(resize_ratio * width), int(resize_ratio * height))
        # Override: force (width, height) = (1920, 1080) so the input matches
        # the --crop_size=1080 --crop_size=1920 used at export time.
        target_size = (1920, 1080)
        # LANCZOS is the filter formerly aliased as ANTIALIAS (removed in Pillow 10).
        resized_image = image.convert('RGB').resize(target_size, Image.LANCZOS)
        batch_seg_map = self.sess.run(
            self.OUTPUT_TENSOR_NAME,
            feed_dict={self.INPUT_TENSOR_NAME: [np.asarray(resized_image)]})
        seg_map = batch_seg_map[0]
        return resized_image, seg_map


def create_pascal_label_colormap():
    """
  Creates a label colormap used in PASCAL VOC segmentation benchmark.
  Returns:
      A Colormap for visualizing segmentation results.
  """
    colormap = np.zeros((256, 3), dtype=int)
    ind = np.arange(256, dtype=int)

    for shift in reversed(range(8)):
        for channel in range(3):
            colormap[:, channel] |= ((ind >> channel) & 1) << shift
        ind >>= 3

    return colormap


def label_to_color_image(label):
    """
  Adds color defined by the dataset colormap to the label.
  Args:
      label: A 2D array with integer type, storing the segmentation label.
  Returns:
      result: A 2D array with floating type. The element of the array
      is the color indexed by the corresponding element in the input label
      to the PASCAL color map.
  Raises:
      ValueError: If label is not of rank 2 or its value is larger than color
      map maximum entry.
  """
    if label.ndim != 2:
        raise ValueError('Expect 2-D input label')

    colormap = create_pascal_label_colormap()

    if np.max(label) >= len(colormap):
        raise ValueError('label value too large.')

    return colormap[label]


def vis_segmentation(image, seg_map):
    """Visualizes input image, segmentation map and overlay view."""
    plt.figure(figsize=(15, 5))
    grid_spec = gridspec.GridSpec(1, 4, width_ratios=[6, 6, 6, 1])

    plt.subplot(grid_spec[0])
    plt.imshow(image)
    plt.axis('off')
    plt.title('input image')

    plt.subplot(grid_spec[1])
    seg_image = label_to_color_image(seg_map).astype(np.uint8)
    plt.imshow(seg_image)
    plt.axis('off')
    plt.title('segmentation map')

    plt.subplot(grid_spec[2])
    plt.imshow(image)
    plt.imshow(seg_image, alpha=0.7)
    plt.axis('off')
    plt.title('segmentation overlay')

    unique_labels = np.unique(seg_map)
    ax = plt.subplot(grid_spec[3])
    plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation='nearest')
    ax.yaxis.tick_right()
    plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
    plt.xticks([], [])
    ax.tick_params(width=0.0)
    plt.grid(False)
    plt.show()

LABEL_NAMES = np.asarray(
    ['background', 'blackboard','screen'])

FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)



model_path = r"D:\python_project\deeplabv3+\blackboard_v2.pb"  # exported frozen graph (.pb)

MODEL = DeepLabModel(model_path)
print('model loaded successfully!')


def run_visualization(imagefile):
    """Runs DeepLab semantic segmentation on one image and visualizes the result."""
    original_im = Image.open(imagefile)
    print('running deeplab on image %s...' % imagefile)
    resized_im, seg_map = MODEL.run(original_im)
    print(seg_map.shape)

    vis_segmentation(resized_im, seg_map)


images_dir = r'D:\python_project\deeplabv3+\test_img'  # directory holding the test images
images = sorted(os.listdir(images_dir))
for imgfile in images:
    run_visualization(os.path.join(images_dir, imgfile))

print('Done.')

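Two things to note:

1. Set images_dir to the directory that holds your own test images.

2. Set INPUT_SIZE = 1920 to the larger of your images' height and width. Note that run() then overrides target_size with a fixed (1920, 1080) (width, height); adjust that line to match the --crop_size values used at export.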
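Beyond the visual check, a quick numeric summary of each prediction can help spot a model that predicts only background. A minimal sketch with a hypothetical helper class_pixel_fractions; in the script above you would pass it the seg_map returned by MODEL.run():

```python
import numpy as np

LABEL_NAMES = ['background', 'blackboard', 'screen']


def class_pixel_fractions(seg_map, label_names=LABEL_NAMES):
    """Fraction of pixels assigned to each class in a 2-D segmentation map."""
    seg_map = np.asarray(seg_map)
    counts = np.bincount(seg_map.ravel(), minlength=len(label_names))
    return {name: float(counts[i]) / seg_map.size
            for i, name in enumerate(label_names)}


# Tiny synthetic map standing in for a real prediction.
demo = np.array([[0, 0, 1, 2]])
print(class_pixel_fractions(demo))
# {'background': 0.5, 'blackboard': 0.25, 'screen': 0.25}
```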

[Figure: sample test results]