1. 程式人生 > deeplabv3+ demo 測試圖像分割

deeplabv3+ demo測試圖像分割

target farm run ges obi pyplot ids base overlay


 # -*- coding: utf-8 -*-
"""DeepLab v3+ demo: segment every image in a directory and plot the result.

The script loads a frozen DeepLab inference graph from a local tarball, runs
it on each file found in ``./pictures`` and shows/saves a figure with the
input image, the colorized segmentation map, an overlay and a label legend.

NOTE(review): the model code uses the TensorFlow 1.x graph/session API
(``tf.GraphDef``, ``tf.Session``). Under TensorFlow 2.x these names are
presumably only reachable through ``tf.compat.v1`` -- confirm against the
installed TensorFlow version.
"""

import os
import tarfile

from matplotlib import gridspec
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import tempfile
from six.moves import urllib

import tensorflow as tf


class DeepLabModel(object):
    """Loads a frozen DeepLab model and runs inference with it."""

    INPUT_TENSOR_NAME = 'ImageTensor:0'
    OUTPUT_TENSOR_NAME = 'SemanticPredictions:0'
    INPUT_SIZE = 513
    FROZEN_GRAPH_NAME = 'frozen_inference_graph'

    def __init__(self, tarball_path):
        """Load the pretrained model stored in a tar archive.

        Args:
            tarball_path: Path to a tar archive containing a member whose
                base name includes ``FROZEN_GRAPH_NAME``.

        Raises:
            RuntimeError: If no such member is found in the archive.
        """
        self.graph = tf.Graph()

        graph_def = None
        # Extract the frozen graph from the tar archive. The context managers
        # make sure the archive and the member are closed even if parsing
        # the graph raises.
        with tarfile.open(tarball_path) as tar_file:
            for tar_info in tar_file.getmembers():
                if self.FROZEN_GRAPH_NAME in os.path.basename(tar_info.name):
                    with tar_file.extractfile(tar_info) as file_handle:
                        graph_def = tf.GraphDef.FromString(file_handle.read())
                    break

        if graph_def is None:
            raise RuntimeError('Cannot find inference graph in tar archive.')

        with self.graph.as_default():
            # name='' imports the nodes without a prefix, so the tensor names
            # declared on the class can be used as-is.
            tf.import_graph_def(graph_def, name='')

        self.sess = tf.Session(graph=self.graph)

    def run(self, image):
        """Run inference on a single image.

        Args:
            image: A ``PIL.Image`` object holding the original picture (not a
                file path).

        Returns:
            resized_image: RGB image resized from the original input image so
                that its longer side equals ``INPUT_SIZE``.
            seg_map: Segmentation map of ``resized_image``.
        """
        width, height = image.size
        resize_ratio = 1.0 * self.INPUT_SIZE / max(width, height)
        target_size = (int(resize_ratio * width), int(resize_ratio * height))
        # Image.ANTIALIAS was an alias of Image.LANCZOS and was removed in
        # Pillow 10; LANCZOS selects the same filter on old and new releases.
        resized_image = image.convert('RGB').resize(target_size, Image.LANCZOS)
        batch_seg_map = self.sess.run(
            self.OUTPUT_TENSOR_NAME,
            feed_dict={self.INPUT_TENSOR_NAME: [np.asarray(resized_image)]})
        # A batch of one image was fed, so take the first (only) result.
        seg_map = batch_seg_map[0]
        return resized_image, seg_map


def create_pascal_label_colormap():
    """Create the label colormap used in the PASCAL VOC benchmark.

    Returns:
        An integer array of shape (256, 3): one RGB color per label id.
    """
    colormap = np.zeros((256, 3), dtype=int)
    ind = np.arange(256, dtype=int)

    # Each pass consumes the three lowest bits of the label id (one per
    # channel) and places them at the current bit position of the color,
    # starting from the most significant bit.
    for shift in reversed(range(8)):
        for channel in range(3):
            colormap[:, channel] |= ((ind >> channel) & 1) << shift
        ind >>= 3

    return colormap


def label_to_color_image(label):
    """Map a 2-D label array to colors from the PASCAL colormap.

    Args:
        label: A 2-D array with integer type, storing the segmentation label.

    Returns:
        An integer array with one extra trailing axis of size 3: each element
        of ``label`` is replaced by its RGB color from the PASCAL colormap.

    Raises:
        ValueError: If label is not of rank 2 or its value is larger than
            the colormap's maximum entry.
    """
    if label.ndim != 2:
        raise ValueError('Expect 2-D input label')

    colormap = create_pascal_label_colormap()

    if np.max(label) >= len(colormap):
        raise ValueError('label value too large.')

    return colormap[label]


def vis_segmentation(image, seg_map, imagefile):
    """Plot the input image, the segmentation map and their overlay.

    A fourth, narrow panel shows a legend with the labels present in
    ``seg_map``. The finished figure is saved next to the input file as
    ``'./' + imagefile + '.png'`` and then displayed.

    Args:
        image: The (resized) input image to display.
        seg_map: 2-D integer label array produced for ``image``.
        imagefile: Path of the source image; used to build the output name.
    """
    fig = plt.figure(figsize=(15, 5))
    grid_spec = gridspec.GridSpec(1, 4, width_ratios=[6, 6, 6, 1])

    plt.subplot(grid_spec[0])
    plt.imshow(image)
    plt.axis('off')
    plt.title('input image')

    plt.subplot(grid_spec[1])
    seg_image = label_to_color_image(seg_map).astype(np.uint8)
    plt.imshow(seg_image)
    plt.axis('off')
    plt.title('segmentation map')

    plt.subplot(grid_spec[2])
    plt.imshow(image)
    plt.imshow(seg_image, alpha=0.7)
    plt.axis('off')
    plt.title('segmentation overlay')

    unique_labels = np.unique(seg_map)
    ax = plt.subplot(grid_spec[3])
    plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8),
               interpolation='nearest')
    ax.yaxis.tick_right()
    plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
    plt.xticks([], [])
    ax.tick_params(width=0.0)
    # A boolean is required here: the string 'off' is truthy and would turn
    # the grid on in current Matplotlib releases.
    plt.grid(False)

    # Save only once every panel is drawn; saving earlier would write a
    # half-finished figure to disk.
    plt.savefig('./' + imagefile + '.png')
    plt.show()
    # Release the figure so they do not pile up when many images are processed.
    plt.close(fig)


# Class names of the PASCAL VOC label set, indexed by label id.
LABEL_NAMES = np.asarray([
    'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
    'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
    'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tv',
])

FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)


# Pretrained models published by TensorFlow. MODEL_NAME must be one of:
# 'mobilenetv2_coco_voctrainaug', 'mobilenetv2_coco_voctrainval',
# 'xception_coco_voctrainaug', 'xception_coco_voctrainval'.
MODEL_NAME = 'xception_coco_voctrainval'

_DOWNLOAD_URL_PREFIX = 'http://download.tensorflow.org/models/'
_MODEL_URLS = {
    'mobilenetv2_coco_voctrainaug':
        'deeplabv3_mnv2_pascal_train_aug_2018_01_29.tar.gz',
    'mobilenetv2_coco_voctrainval':
        'deeplabv3_mnv2_pascal_trainval_2018_01_29.tar.gz',
    'xception_coco_voctrainaug':
        'deeplabv3_pascal_train_aug_2018_01_04.tar.gz',
    'xception_coco_voctrainval':
        'deeplabv3_pascal_trainval_2018_01_04.tar.gz',
}

_TARBALL_NAME = 'deeplab_model.tar.gz'

# The model tarball is expected in the current directory. To use a scratch
# directory instead: model_dir = tempfile.mkdtemp()
model_dir = './'

download_path = os.path.join(model_dir, _TARBALL_NAME)
print('downloading model, this might take a while...')
# The download is disabled: a local copy of the tarball is used. Uncomment
# the next line to fetch the model selected by MODEL_NAME instead.
# urllib.request.urlretrieve(
#     _DOWNLOAD_URL_PREFIX + _MODEL_URLS[MODEL_NAME], download_path)
print('download completed! loading DeepLab model...')

# download_path resolves to './deeplab_model.tar.gz'.
MODEL = DeepLabModel(download_path)
print('model loaded successfully!')


def run_visualization(imagefile):
    """Run DeepLab semantic segmentation on a file and visualize the result.

    Args:
        imagefile: Path of the image file to segment.
    """
    print('running deeplab on image %s...' % imagefile)
    resized_im, seg_map = MODEL.run(Image.open(imagefile))

    vis_segmentation(resized_im, seg_map, imagefile)


images_dir = './pictures'
images = sorted(os.listdir(images_dir))
print(images)
for imgfile in images:
    run_visualization(os.path.join(images_dir, imgfile))

print('Done.')

所使用的模型是本地的 deeplab_model.tar.gz;也可以修改代碼,改用在標準數據集上預訓練過的其他模型,相關設定見代碼中的 MODEL_NAME 與 _MODEL_URLS。

1. 修改模型保存路徑(代碼中的 model_dir 及載入模型時使用的壓縮包路徑)

2. 修改圖片路徑(代碼中的 images_dir)

3. 運行腳本即可

參考自:https://www.aiuai.cn/aifarm252.html

deeplabv3+ demo測試圖像分割