TensorFlow Object Detection API：生成自己的 TFRecord 訓練資料集
阿新 • 發佈：2018-12-20
谷歌 Object Detection API
本文部分內容參考了其他大佬的文章，由於忘了出處，所以沒有附上轉載連結，敬請諒解；原作者看到可以聯絡我補上。
========轉載請註明出處==========
此 Python 檔案放在 dataset_tools 目錄下。
生成自己的訓練資料集，主要取決於個人 annotation 檔案的格式。我這裡每張圖都有各自的 annotation 檔案，例如：
圖片 xxx.jpg 的 annotation 檔案為 xxx.box。
box 檔案每一行的內容為：
Xmin Ymin Xmax Ymax label（以空格分隔）。如果有多個 label，可以繼續追加在下一行，如下所示：
Xmin Ymin Xmax Ymax label \n
Xmin Ymin Xmax Ymax label
from __future__ import absolute_import from __future__ import division from __future__ import print_function import hashlib import io import os import PIL.Image import tensorflow as tf import pandas as pd import cv2 from functools import reduce import operator from object_detection.utils import dataset_util flags = tf.app.flags flags.DEFINE_string('train_imgs_dir', '/home/ai/Downloads/competition_change_box_img/img', 'Root directory to bc train dataset.') flags.DEFINE_string('train_labels', '/home/ai/Downloads/competition_change_box_img/box', '(Relative) path to annotations directory.') flags.DEFINE_string('train_output', '../All_tf_record/competition_img_test.record', 'Path to output TFRecord') FLAGS = flags.FLAGS def create_coordinate_info_of_content_list(image_dir,label_dir): content_list_all = [] for item,file_name in enumerate(os.listdir(label_dir)): img = cv2.imread(os.path.join(image_dir,file_name.replace('.box','.jpg'))) height = img.shape[0] width = img.shape[1] deepth = img.shape[2] content_list = [[file_name.replace('.box', '.jpg'), height, width, deepth]] with open(os.path.join(label_dir,file_name), 'r') as f: lines = f.readlines() for line in lines: new_line = line.split(' ')[:] content_one = [new_line[0],new_line[1],new_line[2],new_line[3],new_line[4]] content_list.append(content_one) a = reduce(operator.add,content_list) content_list_all.append(a) return content_list_all def create_tf_example(content_list, imgs_dir): height = int(content_list[1]) width = int(content_list[2]) filename = content_list[0] img_path = os.path.join(imgs_dir, filename) with tf.gfile.GFile(img_path, 'rb') as fid: encoded_jpg = fid.read() encoded_jpg_io = io.BytesIO(encoded_jpg) image = PIL.Image.open(encoded_jpg_io) if image.format != 'JPEG': raise ValueError('Image format not JPEG') key = hashlib.sha256(encoded_jpg).hexdigest() xmin = [] ymin = [] xmax = [] ymax = [] classes = [] classes_text = [] box_num = int((len(content_list) - 4) / 5) #一張圖上可能有多個label for i in 
range(box_num): xmin.append(float(content_list[5 * i + 4 + 0]) / width) ymin.append(float(content_list[5 * i + 4 + 1]) / height) xmax.append(float(content_list[5 * i + 4 + 2]) / width) ymax.append(float(content_list[5 * i + 4 + 3]) / height) classes_text.append(content_list[5 * i + 4 + 4].encode('utf8')) classes.append(classMap[content_list[5 * i + 4 + 4]]) print('the class id is {} '.format(classMap[content_list[5 * i + 4 + 4]])) example = tf.train.Example(features=tf.train.Features(feature={ 'image/height': dataset_util.int64_feature(height), 'image/width': dataset_util.int64_feature(width), 'image/filename': dataset_util.bytes_feature( filename.encode('utf8')), 'image/source_id': dataset_util.bytes_feature( filename.encode('utf8')), 'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')), 'image/encoded': dataset_util.bytes_feature(encoded_jpg), 'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')), 'image/object/bbox/xmin': dataset_util.float_list_feature(xmin), 'image/object/bbox/xmax': dataset_util.float_list_feature(xmax), 'image/object/bbox/ymin': dataset_util.float_list_feature(ymin), 'image/object/bbox/ymax': dataset_util.float_list_feature(ymax), 'image/object/class/text': dataset_util.bytes_list_feature(classes_text), 'image/object/class/label': dataset_util.int64_list_feature(classes), })) return example def main(_): # train tfrecord generate print("Reading from {}".format(FLAGS.train_imgs_dir)) writer = tf.python_io.TFRecordWriter(FLAGS.train_output) content_list_all = create_coordinate_info_of_content_list(FLAGS.train_imgs_dir, FLAGS.train_labels) for line in content_list_all: content_list = line tf_example = create_tf_example(content_list, FLAGS.train_imgs_dir) writer.write(tf_example.SerializeToString()) writer.close() if __name__ == '__main__': tf.app.run()