1. 程式人生 > Imagenet標註檔案的Read和Write

Imagenet標註檔案的Read和Write

image_label_util.py

#coding:utf-8
import os, cv2, shutil, random, codecs, HTMLParser
from lxml import etree
from lxml.etree import Element, SubElement, tostring

class PicAnno:
    """In-memory model of one VOC-style image annotation.

    Holds the image-level fields (folder, filename, size) plus a list of
    PicObject instances.  Optional fields start out as None so that code
    reading an instance never hits AttributeError for a field that was
    simply never set.
    """

    def __init__(self, folder):
        # Per-instance list (deliberately not a class attribute) so separate
        # annotations can never share the same object list.
        self.objects = []
        self.folder = folder
        self.filename = None
        self.width = None
        self.height = None
        self.depth = None

    def set_folder(self, folder):
        self.folder = folder

    def set_filename(self, filename):
        self.filename = filename

    def set_size(self, width, height, depth):
        """Store the image dimensions as given (strings when read from XML)."""
        self.width = width
        self.height = height
        self.depth = depth

    def add_object(self, object):  # parameter name kept for keyword callers
        """Append a PicObject to this annotation."""
        self.objects.append(object)


class PicObject:
    """One annotated object: class name, pose/truncated/difficult flags and
    bounding box.  Fields not yet set are None."""

    def __init__(self, name):
        self.name = name
        self.pose = None
        self.truncated = None
        self.difficult = None
        self.xmin = None
        self.ymin = None
        self.xmax = None
        self.ymax = None

    def set_name(self, name):
        self.name = name

    def set_pose(self, pose):
        self.pose = pose

    def set_truncated(self, truncated):
        self.truncated = truncated

    def set_difficult(self, difficult):
        self.difficult = difficult

    def set_bndbox(self, xmin, ymin, xmax, ymax):
        """Store the bounding-box corners as given (strings when read from XML)."""
        self.xmin = xmin
        self.ymin = ymin
        self.xmax = xmax
        self.ymax = ymax


class VocUtil:
    """Read/write VOC-style annotation XML files and write image-set lists."""

    @staticmethod
    def _xpath_text(node, path):
        """Return the text of the first element matching ``path`` under
        ``node``, or None when nothing matches (or the element is empty)."""
        found = node.xpath(path)
        return found[0].text if found else None

    def read_anno_xml(self, xml_path, encoding=None):
        """Parse an annotation XML file into a PicAnno.

        :param xml_path: path of the XML file.
        :param encoding: optional encoding that overrides the document's own
            declaration (e.g. 'gbk' for GBK-encoded files).  By default the
            document's declaration is used.
        :return: a PicAnno; elements missing from the file become None
            instead of raising IndexError.
        """
        # Only build a custom parser when an override is requested, so the
        # default path keeps using lxml's default parser.
        parser = etree.XMLParser(encoding=encoding) if encoding else None
        root = etree.parse(xml_path, parser).getroot()
        text = self._xpath_text

        picAnno = PicAnno(text(root, '/annotation/folder'))
        picAnno.set_filename(text(root, '/annotation/filename'))
        picAnno.set_size(text(root, '/annotation/size/width'),
                         text(root, '/annotation/size/height'),
                         text(root, '/annotation/size/depth'))
        for obj in root.xpath('/annotation/object'):
            picObject = PicObject(text(obj, 'name'))
            picObject.set_pose(text(obj, 'pose'))
            picObject.set_truncated(text(obj, 'truncated'))
            picObject.set_difficult(text(obj, 'difficult'))
            picObject.set_bndbox(text(obj, 'bndbox/xmin'),
                                 text(obj, 'bndbox/ymin'),
                                 text(obj, 'bndbox/xmax'),
                                 text(obj, 'bndbox/ymax'))
            picAnno.add_object(picObject)
        return picAnno

    @staticmethod
    def _add_node(parent, tag, value, stringify=True):
        """Append a <tag> child to ``parent`` and, unless ``value`` is None,
        set its text.

        With ``stringify`` the value goes through str() (numbers etc.).
        Free-text fields (folder, filename, name) are assigned unchanged so
        non-ASCII unicode is not forced through str(), which fails on
        Python 2.
        """
        node = SubElement(parent, tag)
        if value is not None:
            node.text = str(value) if stringify else value
        return node

    def parse_anno_xml(self, picAnno):
        """Serialise a PicAnno into a pretty-printed VOC annotation XML string.

        Every element is always emitted; attributes that are missing or None
        yield an empty element.  The result is a text (unicode) string in
        which non-ASCII characters appear literally while '&', '<' and '>'
        remain escaped, so it is well-formed XML ready for save_anno_xml.
        """
        add = self._add_node
        node_root = Element('annotation')
        add(node_root, 'folder', getattr(picAnno, 'folder', None), stringify=False)
        add(node_root, 'filename', getattr(picAnno, 'filename', None), stringify=False)

        node_size = SubElement(node_root, 'size')
        for tag in ('width', 'height', 'depth'):
            # getattr keeps a PicAnno whose size was never set serialisable.
            add(node_size, tag, getattr(picAnno, tag, None))

        for obj in picAnno.objects:
            node_object = SubElement(node_root, 'object')
            add(node_object, 'name', getattr(obj, 'name', None), stringify=False)
            for tag in ('pose', 'truncated', 'difficult'):
                add(node_object, tag, getattr(obj, tag, None))
            node_bndbox = SubElement(node_object, 'bndbox')
            for tag in ('xmin', 'ymin', 'xmax', 'ymax'):
                add(node_bndbox, tag, getattr(obj, tag, None))

        # encoding='unicode' emits real characters instead of &#NNNN;
        # references, so no HTML-unescape pass is needed.  Such a pass would
        # also turn &amp;/&lt; back into raw '&'/'<' and break the XML.
        return tostring(node_root, pretty_print=True, encoding='unicode')

    def save_anno_xml(self, xml_path, xml_text):
        """Write ``xml_text`` to ``xml_path`` encoded as UTF-8."""
        with codecs.open(xml_path, 'w', 'utf-8') as f:
            f.write(xml_text)

    def readFile(self, path):
        """Return the lines of a text file with surrounding whitespace stripped."""
        with open(path, 'r') as f:
            return [line.strip() for line in f]

    def writeLines(self, file_path, lines):
        """Write each line (whitespace-stripped, newline-terminated) to
        ``file_path``, creating its parent directory when needed."""
        file_dir = os.path.dirname(file_path)
        # dirname() is '' for a bare filename; os.makedirs('') would raise.
        if file_dir and not os.path.exists(file_dir):
            os.makedirs(file_dir)
        with open(file_path, 'w') as f:
            f.writelines(line.strip() + '\n' for line in lines)

    def gene_train_test_val_txt(self, anno_dir, txt_dir):
        """Write test/train/trainval/val image-set files into ``txt_dir``,
        listing the stems of the ``*.xml`` files found in ``anno_dir``.

        NOTE(review): each of the four files receives *every* annotation
        stem, independently shuffled; no actual split is performed.  Confirm
        this is intended before using the lists for evaluation.
        """
        # splitext keeps dots inside the stem ('img.001.xml' -> 'img.001').
        pic_names = [os.path.splitext(pic_name)[0]
                     for pic_name in os.listdir(anno_dir)
                     if pic_name.endswith('.xml')]
        for set_name in ('test', 'train', 'trainval', 'val'):
            random.shuffle(pic_names)
            self.writeLines(os.path.join(txt_dir, set_name + '.txt'), pic_names)
        print('生成測試集、訓練集、訓練驗證集、驗證集完成!')