ImageNet 標註檔案的 Read 和 Write
阿新 • • 發佈:2019-01-06
image_label_util.py
#coding:utf-8
import os, cv2, shutil, random, codecs, HTMLParser
from lxml import etree
from lxml.etree import Element, SubElement, tostring
class PicAnno:
    """Annotation record for one picture.

    Holds the folder, file name, image size and the list of annotated
    objects.  Every attribute exists from construction on; the ones that
    have not been supplied yet are ``None`` so that readers can test
    ``attr is not None`` without risking an ``AttributeError``.
    """

    def __init__(self, folder):
        """Create an annotation with no objects for a picture in *folder*."""
        self.folder = folder
        self.filename = None
        self.width = None
        self.height = None
        self.depth = None
        # Created per instance so two annotations never share one list.
        self.objects = []

    def set_folder(self, folder):
        """Replace the folder value."""
        self.folder = folder

    def set_filename(self, filename):
        """Store the picture's file name."""
        self.filename = filename

    def set_size(self, width, height, depth):
        """Store the picture's width, height and depth as given."""
        self.width = width
        self.height = height
        self.depth = depth

    def add_object(self, object):  # parameter name kept for keyword callers
        """Append one annotated object to ``self.objects``."""
        self.objects.append(object)
class PicObject:
    """One annotated object inside a picture: a name plus optional details.

    All attributes exist from construction on; those not supplied yet are
    ``None``, so reading them before the matching setter was called does
    not raise ``AttributeError``.
    """

    def __init__(self, name):
        """Create an object called *name* with every other field unset."""
        self.name = name
        self.pose = None
        self.truncated = None
        self.difficult = None
        self.xmin = None
        self.ymin = None
        self.xmax = None
        self.ymax = None

    def set_name(self, name):
        """Replace the object's name."""
        self.name = name

    def set_pose(self, pose):
        """Store the pose value as given."""
        self.pose = pose

    def set_truncated(self, truncated):
        """Store the truncated value as given."""
        self.truncated = truncated

    def set_difficult(self, difficult):
        """Store the difficult value as given."""
        self.difficult = difficult

    def set_bndbox(self, xmin, ymin, xmax, ymax):
        """Store the four bounding-box coordinates as given."""
        self.xmin = xmin
        self.ymin = ymin
        self.xmax = xmax
        self.ymax = ymax
class VocUtil:
    """Read, build and save annotation XML, plus small text-file helpers."""

    @staticmethod
    def _first_text(node, path):
        """Return the text of the first element matching *path*, or None."""
        matches = node.xpath(path)
        return matches[0].text if matches else None

    @staticmethod
    def _append_node(parent, tag, value=None, as_str=False):
        """Append a *tag* child to *parent* and return it.

        The child's text is set only when *value* is not None; with
        ``as_str=True`` the value is passed through ``str()`` first.
        """
        node = SubElement(parent, tag)
        if value is not None:
            node.text = str(value) if as_str else value
        return node

    def read_anno_xml(self, xml_path):
        """Parse the annotation file at *xml_path* into a ``PicAnno``.

        Element texts are stored as read (strings).  An element that is
        absent from the file yields ``None`` for that field instead of
        raising ``IndexError``.
        """
        root = etree.parse(xml_path).getroot()
        # NOTE: for a GBK-encoded file, decode its text yourself (e.g. with
        # codecs.open(xml_path, 'r', 'gbk')) and use etree.fromstring.
        text = self._first_text

        picAnno = PicAnno(text(root, '/annotation/folder'))
        picAnno.set_filename(text(root, '/annotation/filename'))
        picAnno.set_size(text(root, '/annotation/size/width'),
                         text(root, '/annotation/size/height'),
                         text(root, '/annotation/size/depth'))

        for obj in root.xpath('/annotation/object'):
            picObject = PicObject(text(obj, 'name'))
            picObject.set_pose(text(obj, 'pose'))
            picObject.set_truncated(text(obj, 'truncated'))
            picObject.set_difficult(text(obj, 'difficult'))
            picObject.set_bndbox(text(obj, 'bndbox/xmin'),
                                 text(obj, 'bndbox/ymin'),
                                 text(obj, 'bndbox/xmax'),
                                 text(obj, 'bndbox/ymax'))
            picAnno.add_object(picObject)
        return picAnno

    def parse_anno_xml(self, picAnno):
        """Serialise *picAnno* to a pretty-printed XML text string.

        Every element is always emitted; its text is left empty when the
        matching attribute is missing or ``None``.  The result is text (not
        bytes), so non-ASCII characters appear literally while ``&`` and
        ``<`` stay escaped and the XML remains well-formed.
        """
        add = self._append_node

        node_root = Element('annotation')
        add(node_root, 'folder', getattr(picAnno, 'folder', None))
        add(node_root, 'filename', getattr(picAnno, 'filename', None))

        node_size = SubElement(node_root, 'size')
        for tag in ('width', 'height', 'depth'):
            add(node_size, tag, getattr(picAnno, tag, None), as_str=True)

        for obj in picAnno.objects:
            node_object = SubElement(node_root, 'object')
            add(node_object, 'name', getattr(obj, 'name', None))
            for tag in ('pose', 'truncated', 'difficult'):
                add(node_object, tag, getattr(obj, tag, None), as_str=True)
            node_bndbox = SubElement(node_object, 'bndbox')
            for tag in ('xmin', 'ymin', 'xmax', 'ymax'):
                add(node_bndbox, tag, getattr(obj, tag, None), as_str=True)

        # encoding='unicode' makes lxml return text directly; this replaces
        # the str()/encode()/HTML-unescape round trip, which un-escaped
        # markup characters and does not work on Python 3.
        return tostring(node_root, pretty_print=True, encoding='unicode')

    def save_anno_xml(self, xml_path, xml_text):
        """Write *xml_text* to *xml_path* encoded as UTF-8.

        A byte string is accepted too and is assumed to be UTF-8.
        """
        if isinstance(xml_text, bytes):
            xml_text = xml_text.decode('utf-8')
        with codecs.open(xml_path, 'w', 'utf-8') as f:
            f.write(xml_text)

    def readFile(self, path):
        """Return the lines of the file at *path*, each stripped of whitespace."""
        with open(path, 'r') as f:
            return [line.strip() for line in f]

    def writeLines(self, file_path, lines):
        """Write *lines* to *file_path*, one stripped line per row.

        The parent directory is created when it does not exist yet.
        """
        file_dir = os.path.dirname(file_path)
        # dirname is '' for a bare file name; os.makedirs('') would raise.
        if file_dir and not os.path.exists(file_dir):
            os.makedirs(file_dir)
        with open(file_path, 'w') as fr:
            fr.writelines(line.strip() + '\n' for line in lines)

    def gene_train_test_val_txt(self, anno_dir, txt_dir):
        """Write test/train/trainval/val name lists for the XML files in *anno_dir*.

        Each of the four ``.txt`` files in *txt_dir* receives ALL annotation
        names (file name without the ``.xml`` extension), reshuffled before
        every write; the lists are not disjoint subsets.
        """
        # splitext keeps dots inside the name ('a.b.xml' -> 'a.b').
        pic_names = [os.path.splitext(pic_name)[0]
                     for pic_name in os.listdir(anno_dir)
                     if pic_name.endswith('.xml')]
        for split in ('test', 'train', 'trainval', 'val'):
            random.shuffle(pic_names)
            self.writeLines(os.path.join(txt_dir, split + '.txt'), pic_names)
        print('生成測試集、訓練集、訓練驗證集、驗證集完成!')