Tensorflow Object Detection 生成自己的mask_rcnn資料集
阿新 • 發佈:2018-12-20
此文章參考自大神shirhe的github,由於之前大神寫的有點bug,mask訓練出來後檢測不會出來mask掩碼,所以自己研究了下改的新的(最近大神也修改了這個bug),轉載請註明出處。
有三個python檔案需要建立(如果是labelme生成的是xml檔案,另附一份xml轉json程式碼)
json檔案格式請參考上述github中的json檔案
string_int_label_map_pb2.py
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: object_detection/protos/string_int_label_map.proto
#
# NOTE(review): auto-generated protobuf bindings for the label map protos:
# StringIntLabelMapItem {name, id, display_name} and StringIntLabelMap
# (a repeated list of items). Do not hand-edit; regenerate with protoc
# from string_int_label_map.proto if the schema changes.

import sys
# On Python 2 the serialized descriptor blob stays a str; on Python 3 it
# must be latin-1 encoded to bytes.
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()


# File-level descriptor holding the serialized .proto definition.
DESCRIPTOR = _descriptor.FileDescriptor(
  name='object_detection/protos/string_int_label_map.proto',
  package='object_detection.protos',
  syntax='proto2',
  serialized_options=None,
  serialized_pb=_b('\n2object_detection/protos/string_int_label_map.proto\x12\x17object_detection.protos\"G\n\x15StringIntLabelMapItem\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\n\n\x02id\x18\x02 \x01(\x05\x12\x14\n\x0c\x64isplay_name\x18\x03 \x01(\t\"Q\n\x11StringIntLabelMap\x12<\n\x04item\x18\x01 \x03(\x0b\x32..object_detection.protos.StringIntLabelMapItem')
)


# Descriptor for one label-map entry: optional string name (field 1),
# optional int32 id (field 2), optional string display_name (field 3).
_STRINGINTLABELMAPITEM = _descriptor.Descriptor(
  name='StringIntLabelMapItem',
  full_name='object_detection.protos.StringIntLabelMapItem',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
    _descriptor.FieldDescriptor(
      name='name', full_name='object_detection.protos.StringIntLabelMapItem.name', index=0,
      number=1, type=9, cpp_type=9, label=1,
      has_default_value=False, default_value=_b("").decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR),
    _descriptor.FieldDescriptor(
      name='id', full_name='object_detection.protos.StringIntLabelMapItem.id', index=1,
      number=2, type=5, cpp_type=1, label=1,
      has_default_value=False, default_value=0,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR),
    _descriptor.FieldDescriptor(
      name='display_name', full_name='object_detection.protos.StringIntLabelMapItem.display_name', index=2,
      number=3, type=9, cpp_type=9, label=1,
      has_default_value=False, default_value=_b("").decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  serialized_options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=79,
  serialized_end=150,
)


# Descriptor for the top-level message: repeated StringIntLabelMapItem item.
_STRINGINTLABELMAP = _descriptor.Descriptor(
  name='StringIntLabelMap',
  full_name='object_detection.protos.StringIntLabelMap',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
    _descriptor.FieldDescriptor(
      name='item', full_name='object_detection.protos.StringIntLabelMap.item', index=0,
      number=1, type=11, cpp_type=10, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  serialized_options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=152,
  serialized_end=233,
)

# Wire up the nested message type and register everything with the
# default symbol database.
_STRINGINTLABELMAP.fields_by_name['item'].message_type = _STRINGINTLABELMAPITEM
DESCRIPTOR.message_types_by_name['StringIntLabelMapItem'] = _STRINGINTLABELMAPITEM
DESCRIPTOR.message_types_by_name['StringIntLabelMap'] = _STRINGINTLABELMAP
_sym_db.RegisterFileDescriptor(DESCRIPTOR)

# Concrete Python message classes generated from the descriptors above.
StringIntLabelMapItem = _reflection.GeneratedProtocolMessageType('StringIntLabelMapItem', (_message.Message,), dict(
  DESCRIPTOR = _STRINGINTLABELMAPITEM,
  __module__ = 'object_detection.protos.string_int_label_map_pb2'
  # @@protoc_insertion_point(class_scope:object_detection.protos.StringIntLabelMapItem)
  ))
_sym_db.RegisterMessage(StringIntLabelMapItem)

StringIntLabelMap = _reflection.GeneratedProtocolMessageType('StringIntLabelMap', (_message.Message,), dict(
  DESCRIPTOR = _STRINGINTLABELMAP,
  __module__ = 'object_detection.protos.string_int_label_map_pb2'
  # @@protoc_insertion_point(class_scope:object_detection.protos.StringIntLabelMap)
  ))
_sym_db.RegisterMessage(StringIntLabelMap)

# @@protoc_insertion_point(module_scope)
read_pbtxt_file.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""A tool to read .pbtxt label-map files.

Created on Sun Aug 26 13:42:50 2018, @author: shirhe-lyh.

See details at:
    TensorFlow models/research/object_detetion/protos/string_int_label_pb2.py
    TensorFlow models/research/object_detection/utils/label_map_util.py
"""

import tensorflow as tf
from google.protobuf import text_format

import string_int_label_map_pb2


def load_pbtxt_file(path):
    """Read a .pbtxt file into a StringIntLabelMap proto.

    Args:
        path: Path to StringIntLabelMap proto text file (.pbtxt file).

    Returns:
        A StringIntLabelMap proto.

    Raises:
        ValueError: If path is not exist.
    """
    if not tf.gfile.Exists(path):
        raise ValueError('`path` is not exist.')
    with tf.gfile.GFile(path, 'r') as fid:
        raw_text = fid.read()
    label_map_proto = string_int_label_map_pb2.StringIntLabelMap()
    try:
        # Normal case: the file is a human-readable text-format proto.
        text_format.Merge(raw_text, label_map_proto)
    except text_format.ParseError:
        # Fall back to binary wire-format parsing.
        label_map_proto.ParseFromString(raw_text)
    return label_map_proto


def get_label_map_dict(path):
    """Read a .pbtxt file and return a dictionary.

    Args:
        path: Path to StringIntLabelMap proto text file.

    Returns:
        A dictionary mapping class names to integer indices.
    """
    label_map_proto = load_pbtxt_file(path)
    return {item.name: item.id for item in label_map_proto.item}
create_mask_rcnn_tfrecord.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Convert raw dataset to TFRecord for object_detection.

Created on Sun Aug 26 10:57:09 2018, @author: shirhe-lyh.

Please note that this tool only applies to labelme's annotations (json file).

Example usage:
    python3 create_tf_record.py \
        --images_dir=your absolute path to read images.
        --annotations_json_dir=your path to annotaion json files.
        --label_map_path=your path to label_map.pbtxt
        --output_path=your path to write .record.
"""

# FIX: dropped `from numpy import *` (wildcard import shadowing builtins) and
# the debug-only `set_printoptions(threshold=NaN)` (NaN is rejected as a
# threshold by modern numpy); neither affected the produced records.
import glob
import hashlib
import io
import json
import os

import cv2
import numpy as np
import PIL.Image
import tensorflow as tf

import read_pbtxt_file


flags = tf.app.flags

flags.DEFINE_string('images_dir',
                    '/home/ai/Downloads/collection_mask_teeth/imgResize',
                    'Path to images directory.')
flags.DEFINE_string('annotations_json_dir',
                    '/home/ai/Downloads/collection_mask_teeth/AnnotationJson',
                    'Path to annotations directory.')
flags.DEFINE_string('label_map_path',
                    '../data/teeth_label_map.pbtxt',
                    'Path to label map proto.')
flags.DEFINE_string('output_path',
                    '../All_tf_record/teeth_mask.record',
                    'Path to the output tfrecord.')

FLAGS = flags.FLAGS


def int64_feature(value):
    """Wrap a single int in a tf.train.Feature."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def int64_list_feature(value):
    """Wrap a list of ints in a tf.train.Feature."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def bytes_feature(value):
    """Wrap a single bytes value in a tf.train.Feature."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def bytes_list_feature(value):
    """Wrap a list of bytes values in a tf.train.Feature."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))


def float_list_feature(value):
    """Wrap a list of floats in a tf.train.Feature."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))


def create_tf_example(annotation_dict, label_map_dict=None):
    """Converts image and annotations to a tf.Example proto.

    Args:
        annotation_dict: A dictionary containing the following keys:
            ['height', 'width', 'filename', 'sha256_key', 'encoded_jpg',
             'format', 'xmins', 'xmaxs', 'ymins', 'ymaxs', 'masks',
             'class_names'].
        label_map_dict: A dictionary maping class_names to indices.

    Returns:
        example: The converted tf.Example, or None if annotation_dict is None.

    Raises:
        ValueError: If label_map_dict is None or is not containing a
            class_name.
    """
    if annotation_dict is None:
        return None
    if label_map_dict is None:
        raise ValueError('`label_map_dict` is None')

    height = annotation_dict.get('height', None)
    width = annotation_dict.get('width', None)
    filename = annotation_dict.get('filename', None)
    sha256_key = annotation_dict.get('sha256_key', None)
    encoded_jpg = annotation_dict.get('encoded_jpg', None)
    image_format = annotation_dict.get('format', None)
    xmins = annotation_dict.get('xmins', None)
    xmaxs = annotation_dict.get('xmaxs', None)
    ymins = annotation_dict.get('ymins', None)
    ymaxs = annotation_dict.get('ymaxs', None)
    masks = annotation_dict.get('masks', None)
    class_names = annotation_dict.get('class_names', None)

    labels = []
    for class_name in class_names:
        # BUG FIX: the original used .get(class_name, 'None'), which returns
        # the *string* 'None' for an unknown class, so the `is None` check
        # below never fired and a bogus label reached the record.
        label = label_map_dict.get(class_name)
        if label is None:
            raise ValueError('`label_map_dict` is not containing {}.'.format(
                class_name))
        labels.append(label)
    print('image is {},label is {},'.format(filename, class_names))

    # PNG-encode every binary instance mask as required by the
    # object_detection mask decoder.
    encoded_masks = []
    for mask in masks:
        pil_image = PIL.Image.fromarray(mask)
        output_io = io.BytesIO()
        pil_image.save(output_io, format='PNG')
        encoded_masks.append(output_io.getvalue())

    feature_dict = {
        'image/height': int64_feature(height),
        'image/width': int64_feature(width),
        'image/filename': bytes_feature(filename.encode('utf8')),
        'image/source_id': bytes_feature(filename.encode('utf8')),
        'image/key/sha256': bytes_feature(sha256_key.encode('utf8')),
        'image/encoded': bytes_feature(encoded_jpg),
        'image/format': bytes_feature(image_format.encode('utf8')),
        'image/object/bbox/xmin': float_list_feature(xmins),
        'image/object/bbox/xmax': float_list_feature(xmaxs),
        'image/object/bbox/ymin': float_list_feature(ymins),
        'image/object/bbox/ymax': float_list_feature(ymaxs),
        'image/object/mask': bytes_list_feature(encoded_masks),
        'image/object/class/label': int64_list_feature(labels)}
    example = tf.train.Example(features=tf.train.Features(
        feature=feature_dict))
    return example


def _get_annotation_dict(images_dir, annotation_json_path):
    """Get boundingboxes and masks.

    Args:
        images_dir: Path to images directory.
        annotation_json_path: Path to annotated json file corresponding to
            the image. The json file annotated by labelme with keys:
                ['lineColor', 'imageData', 'fillColor', 'imagePath',
                 'shapes', 'flags'].

    Returns:
        annotation_dict: A dictionary containing the following keys:
            ['height', 'width', 'filename', 'sha256_key', 'encoded_jpg',
             'format', 'xmins', 'xmaxs', 'ymins', 'ymaxs', 'masks',
             'class_names'], or None if the paths do not exist or the
            annotation is unusable.

    Raises:
        ValueError: If the image is not JPEG encoded.
    """
    if (not os.path.exists(images_dir) or
            not os.path.exists(annotation_json_path)):
        return None

    with open(annotation_json_path, 'r') as f:
        json_text = json.load(f)
    shapes = json_text.get('shapes', None)
    if shapes is None:
        return None
    image_relative_path = json_text.get('imagePath', None)
    if image_relative_path is None:
        return None
    image_name = image_relative_path.split('/')[-1]
    image_path = os.path.join(images_dir, image_name)
    if not os.path.exists(image_path):
        return None

    # BUG FIX: the original put `image_path` into the 'format' slot; restore
    # the intended (previously commented-out) image format string. TFRecord
    # convention spells it 'jpeg', not 'jpg'.
    image_format = image_name.split('.')[-1].replace('jpg', 'jpeg')
    with tf.gfile.GFile(image_path, 'rb') as fid:
        encoded_jpg = fid.read()
    encoded_jpg_io = io.BytesIO(encoded_jpg)
    image_ = PIL.Image.open(encoded_jpg_io)
    if image_.format != 'JPEG':
        raise ValueError('Image format not JPEG ')
    # Take the dimensions from the already-decoded PIL image instead of
    # re-reading the file with cv2.imread (PIL.size is (width, height)).
    width, height = image_.size
    key = hashlib.sha256(encoded_jpg).hexdigest()

    xmins = []
    xmaxs = []
    ymins = []
    ymaxs = []
    masks = []
    class_names = []
    hole_polygons = []
    for mark in shapes:
        class_name = mark.get('label')
        class_names.append(class_name)
        polygon = mark.get('points')
        polygon = np.array(polygon, dtype=np.int32)
        if class_name == 'hole':
            # 'hole' polygons are not objects: they punch holes into the
            # real object masks below.
            hole_polygons.append(polygon)
        else:
            mask = np.zeros((height, width), dtype='uint8')
            cv2.fillPoly(mask, [polygon], 1)
            masks.append(mask)

            # Axis-aligned bounding box, normalized to [0, 1].
            x = polygon[:, 0]
            y = polygon[:, 1]
            xmins.append(float(np.min(x)) / width)
            xmaxs.append(float(np.max(x)) / width)
            ymins.append(float(np.min(y)) / height)
            ymaxs.append(float(np.max(y)) / height)

    # Remove holes in mask.
    if hole_polygons:
        for mask in masks:
            cv2.fillPoly(mask, hole_polygons, 0)

    annotation_dict = {'height': height,
                       'width': width,
                       'filename': image_name,
                       'sha256_key': key,
                       'encoded_jpg': encoded_jpg,
                       'format': image_format,
                       'xmins': xmins,
                       'xmaxs': xmaxs,
                       'ymins': ymins,
                       'ymaxs': ymaxs,
                       'masks': masks,
                       'class_names': class_names}
    return annotation_dict


def main(_):
    """Walk the annotation directory and write one TFRecord file."""
    if not os.path.exists(FLAGS.images_dir):
        raise ValueError('`images_dir` is not exist.')
    if not os.path.exists(FLAGS.annotations_json_dir):
        raise ValueError('`annotations_json_dir` is not exist.')
    if not os.path.exists(FLAGS.label_map_path):
        raise ValueError('`label_map_path` is not exist.')

    label_map = read_pbtxt_file.get_label_map_dict(FLAGS.label_map_path)

    writer = tf.python_io.TFRecordWriter(FLAGS.output_path)
    try:
        num_annotations_skipped = 0
        annotations_json_path = os.path.join(FLAGS.annotations_json_dir,
                                             '*.json')
        for i, annotation_file in enumerate(glob.glob(annotations_json_path)):
            if i % 100 == 0:
                print('On image %d' % i)
            annotation_dict = _get_annotation_dict(
                FLAGS.images_dir, annotation_file)
            if annotation_dict is None:
                num_annotations_skipped += 1
                continue
            tf_example = create_tf_example(annotation_dict, label_map)
            writer.write(tf_example.SerializeToString())
    finally:
        # FIX: the original never closed the writer; an unflushed record
        # file can be truncated if the process exits abruptly.
        writer.close()

    print('Successfully created TFRecord to {}.'.format(FLAGS.output_path))


if __name__ == '__main__':
    tf.app.run()
大神shirhe提供了json檔案型別的生成tfrecord方式,由於我的工作環境是labelme,輸出預設為xml格式,所以將xml轉換成json檔案。mask訓練相當消耗記憶體,我決定將圖片縮小到原來的一半,注意影象中的座標寬高等尺寸要和json中尺寸對應。要麼都縮小,要麼都不縮。
# -*- coding:utf-8 -*-
import xmltodict
import json
import xml.dom.minidom as xmldom
import cv2
import os
def xml_covert_json(xml_path_dir, json_save_path, imgPath, save, scale=2):
    """Convert LabelMe-style XML annotations to labelme JSON files.

    Each image referenced by an XML file is resized to 1/scale of its
    original width and height and written to `save`; the polygon
    coordinates written into the JSON are divided by the same factor so the
    image and its annotation stay consistent (both shrink, or neither).

    Args:
        xml_path_dir: Directory containing the .xml annotation files.
        json_save_path: Directory to write the converted .json files.
        imgPath: Directory containing the original images (.jpg assumed,
            sharing the XML file's basename).
        save: Directory to write the resized images.
        scale: Down-scaling factor; default 2 preserves the original
            hard-coded halving behaviour.
    """
    xml_list = [os.path.join(xml_path_dir, xml_file)
                for xml_file in os.listdir(xml_path_dir)]
    for xml_path in xml_list:
        dom_tree = xmldom.parse(xml_path)
        collection = dom_tree.documentElement
        # The image shares the XML file's basename. FIX: use splitext
        # instead of split('.')[0], which truncated dotted filenames.
        stem = os.path.splitext(os.path.basename(xml_path))[0]
        filename = stem + '.jpg'
        objects = collection.getElementsByTagName('object')
        shapes_list = []

        print(os.path.basename(os.path.join(imgPath, filename)))
        img_np = cv2.imread(os.path.join(imgPath, filename))
        if img_np is None:
            # BUG FIX: cv2.imread returns None for a missing/unreadable
            # image; the original crashed on `.shape`. Skip the entry.
            print('filename is {},image missing or unreadable'.format(filename))
            continue
        height, width = img_np.shape[0], img_np.shape[1]
        img_np = cv2.resize(img_np, (int(width / scale), int(height / scale)))
        cv2.imwrite(os.path.join(save, filename), img_np)

        for obj in objects:
            polygons = obj.getElementsByTagName('polygon')
            label = obj.getElementsByTagName('name')[0].childNodes[0].data
            is_deleted = obj.getElementsByTagName('deleted')[0].childNodes[0].data
            if int(is_deleted) != 0:
                # Object was deleted in the annotation tool; ignore it.
                continue
            points = []
            for polygon in polygons:
                for pt in polygon.getElementsByTagName('pt'):
                    # Scale the coordinates by the same factor as the image.
                    x = int(int(pt.getElementsByTagName('x')[0].childNodes[0].data) / scale)
                    y = int(int(pt.getElementsByTagName('y')[0].childNodes[0].data) / scale)
                    points.append([x, y])
            # Sanity check that the label is numeric; warn but keep going.
            # BUG FIX: was a bare `except:` which also swallowed
            # KeyboardInterrupt/SystemExit.
            try:
                int(label)
            except ValueError:
                print('filename is {},error label is {}'.format(filename, label))
            shapes_list.append({'points': points,
                                'label': label,
                                'line_color': 'null',
                                'fill_color': 'null'})

        xml_json = stem + '.json'
        with open(os.path.join(json_save_path, xml_json), 'w') as f:
            # FIX: labelme spells this key 'flags' (plural); the original
            # wrote 'flag'. Downstream tfrecord code ignores the key.
            json.dump({'lineColor': [0, 255, 0, 128],
                       'imageData': 'imagedata',
                       'fillColor': [255, 0, 0, 128],
                       'imagePath': filename,
                       'shapes': shapes_list,
                       'flags': {}}, f)
if __name__ == '__main__':
    # Dataset locations for the teeth mask-RCNN data; edit for your own setup.
    xml_covert_json(
        '/home/ai/Downloads/collection_mask_teeth/box',              # XML annotations
        '/home/ai/Downloads/collection_mask_teeth/AnnotationJson',   # JSON output
        '/home/ai/Downloads/collection_mask_teeth/img',              # original images
        '/home/ai/Downloads/collection_mask_teeth/imgResize')        # resized images