目標檢測 的標註數據 .xml 轉為 tfrecord 的格式用於 TensorFlow 訓練
阿新 • • 發佈:2019-01-09
split() leo ofa monit dir txt dining bus not in
將目標檢測的 .xml 標註數據轉為 tfrecord 格式,用於 TensorFlow 訓練。
"""Convert Pascal VOC style XML annotations and JPEG images to a TFRecord file.

Each record holds two features:

* ``xywhc`` -- ``MAX_BOXES * 5`` floats: normalized center-x, center-y,
  width, height and class index for every kept box, zero-padded.
* ``img``   -- the raw bytes of the image resized to ``IMAGE_SIZE`` and
  scaled by 1/255 as float32.
"""
import os
import xml.etree.ElementTree as ET

import numpy as np

# NOTE: TensorFlow and Pillow are imported inside the functions that use them
# so that the pure annotation helpers below stay importable without them.

classes = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car",
           "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike",
           "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]

XML_DIR = 'F:/xml'
IMAGE_DIR = 'F:/snow leopard/test_im/'
MAX_BOXES = 30          # every record stores exactly this many box slots
IMAGE_SIZE = (416, 416)
filename = os.path.join('test' + '.tfrecords')


def convert(size, box):
    """Normalize a corner-style box to center/size form.

    Args:
        size: ``(image_width, image_height)`` in pixels.
        box: ``(xmin, xmax, ymin, ymax)`` in pixels.

    Returns:
        ``[x_center, y_center, width, height]``, each divided by the
        matching image dimension.

    Raises:
        ZeroDivisionError: if either image dimension is 0.
    """
    dw = 1. / size[0]
    dh = 1. / size[1]
    x = (box[0] + box[1]) / 2.0
    y = (box[2] + box[3]) / 2.0
    w = box[1] - box[0]
    h = box[3] - box[2]
    return [x * dw, y * dh, w * dw, h * dh]


def convert_annotation(image_id, xml_dir=XML_DIR):
    """Read one annotation file and return its boxes as a flat float list.

    Objects whose class is not in ``classes`` or that are flagged
    ``difficult == 1`` are skipped. At most ``MAX_BOXES`` boxes are kept.

    Args:
        image_id: annotation file name without the ``.xml`` extension.
        xml_dir: directory that contains the annotation files.

    Returns:
        A list of ``MAX_BOXES * 5`` floats laid out as repeated
        ``[x, y, w, h, class_index]`` groups, zero-padded at the end.
    """
    # Passing a path lets ElementTree open and close the file itself.
    root = ET.parse(os.path.join(xml_dir, '%s.xml' % image_id)).getroot()
    size = root.find('size')
    w = int(size.find('width').text)
    h = int(size.find('height').text)

    bboxes = []
    kept = 0
    for obj in root.iter('object'):
        # Count only boxes that are actually kept, so skipped objects do
        # not use up slots that later valid objects could fill.
        if kept >= MAX_BOXES:
            break
        cls = obj.find('name').text
        # A missing or empty <difficult> tag is treated as "not difficult".
        difficult = int(obj.findtext('difficult') or 0)
        if cls not in classes or difficult == 1:
            continue
        xmlbox = obj.find('bndbox')
        b = tuple(float(xmlbox.find(tag).text)
                  for tag in ('xmin', 'xmax', 'ymin', 'ymax'))
        bboxes.extend(convert((w, h), b) + [classes.index(cls)])
        kept += 1

    # Pad with all-zero groups so every record has the same length.
    bboxes.extend([0, 0, 0, 0, 0] * (MAX_BOXES - kept))
    return np.array(bboxes, dtype=np.float32).flatten().tolist()


def convert_img(image_id, image_dir=IMAGE_DIR):
    """Load a JPEG, resize it to ``IMAGE_SIZE`` and return float32 bytes.

    Args:
        image_id: image file name without the ``.jpg`` extension.
        image_dir: directory that contains the images.

    Returns:
        The raw bytes of a float32 array of the resized RGB image, with
        pixel values divided by 255.
    """
    from PIL import Image

    with Image.open(os.path.join(image_dir, '%s.jpg' % image_id)) as image:
        # Force three channels so every record has the same byte length,
        # whatever mode the JPEG was stored in.
        resized_image = image.convert('RGB').resize(IMAGE_SIZE, Image.BICUBIC)
    image_data = np.array(resized_image, dtype='float32') / 255
    return image_data.tobytes()


def _list_image_ids(image_dir=IMAGE_DIR):
    """Return the sorted base names of the ``.jpg`` files in *image_dir*."""
    image_ids = []
    for entry in sorted(os.listdir(image_dir)):
        # splitext keeps dots inside the name; split('.')[0] would cut there.
        stem, ext = os.path.splitext(entry)
        if ext.lower() == '.jpg':
            image_ids.append(stem)
    return image_ids


def main():
    """Write one TFRecord example per image found in ``IMAGE_DIR``."""
    import tensorflow as tf

    # tf.python_io only exists in TensorFlow 1.x; prefer tf.io when present.
    try:
        writer_cls = tf.io.TFRecordWriter
    except AttributeError:
        writer_cls = tf.python_io.TFRecordWriter

    # The context manager closes the writer even if a conversion fails.
    with writer_cls(filename) as writer:
        for image_id in _list_image_ids():
            print(image_id)
            xywhc = convert_annotation(image_id)
            img_raw = convert_img(image_id)
            example = tf.train.Example(features=tf.train.Features(feature={
                'xywhc': tf.train.Feature(
                    float_list=tf.train.FloatList(value=xywhc)),
                'img': tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[img_raw])),
            }))
            writer.write(example.SerializeToString())


if __name__ == '__main__':
    main()
Python讀取文件夾下圖片的兩種方法:
import os

# List the names of every entry in the images folder.
imagelist = os.listdir('./images/')
import glob

# Collect only the files whose names share the given pattern, in sorted
# order; preferable to a plain os.listdir because it filters by name.
imagelist = sorted(glob.glob('./images/' + 'frame_*.png'))
參考:
https://blog.csdn.net/CV_YOU/article/details/80778392
https://github.com/raytroop/YOLOv3_tf
目標檢測 的標註數據 .xml 轉為 tfrecord 的格式用於 TensorFlow 訓練