1. 程式人生 > 其它 >Faster R-CNN 自定義Dataset

Faster R-CNN 自定義Dataset

技術標籤:faster-RCNN、python、深度學習

bilibili
split_data.py

# Randomly split the annotated samples into a training list and a validation
# list, written to train.txt and val.txt in the current working directory.
import os
import random

# Directory that is scanned for annotation files (one per sample).
files_path = "./VOCdevkit/VOC2012/Annotations"
if not os.path.exists(files_path):
    print("資料夾不存在")
    # SystemExit is what exit() raises; raising it directly also works when
    # the interpreter is started without the `site` module.
    raise SystemExit(1)

# Fraction of the samples that goes into the validation list.
val_rate = 0.5

# A sample name is the file name up to its first dot. Sorting makes the split
# depend only on the random draw, not on the directory listing order.
files_name = sorted([file.split(".")[0] for file in os.listdir(files_path)])
files_num = len(files_name)

# Indices chosen for validation; a set keeps the membership test below O(1).
val_index = set(random.sample(range(0, files_num), k=int(files_num * val_rate)))

train_files = []
val_files = []
for index, file_name in enumerate(files_name):
    if index in val_index:
        val_files.append(file_name)
    else:
        train_files.append(file_name)

try:
    # Mode "x" refuses to overwrite an existing split; the context manager
    # guarantees both files are flushed and closed.
    with open("train.txt", "x") as train_f, open("val.txt", "x") as eval_f:
        train_f.write("\n".join(train_files))
        eval_f.write("\n".join(val_files))
except FileExistsError as e:
    print(e)
    raise SystemExit(1)

my_dataset.py

from torch.utils.data import Dataset
import os
import torch
import json
from PIL import Image
from lxml import etree


class VOC2012DataSet(Dataset):
    """Read and parse the PASCAL VOC2012 dataset.

    Every sample is described by one XML annotation file. The list of
    samples comes from a text file under ``ImageSets/Main`` and the mapping
    from class name to label value from ``./pascal_voc_classes.json``.
    """

    def __init__(self, voc_root, transforms, txt_name: str = "train.txt"):
        """Collect the annotation paths and load the class-name mapping.

        Args:
            voc_root: directory that contains ``VOCdevkit/VOC2012``.
            transforms: callable taking and returning ``(image, target)``,
                or ``None`` to return samples untransformed.
            txt_name: name of the split file inside ``ImageSets/Main``.

        Raises:
            AssertionError: if the split file does not exist.
            SystemExit: if the class-mapping JSON cannot be read or parsed.
        """
        self.root = os.path.join(voc_root, "VOCdevkit", "VOC2012")
        self.img_root = os.path.join(self.root, "JPEGImages")
        self.annotations_root = os.path.join(self.root, "Annotations")

        # read train.txt or val.txt file
        txt_path = os.path.join(self.root, "ImageSets", "Main", txt_name)
        assert os.path.exists(txt_path), "not found {} file.".format(txt_name)

        with open(txt_path) as read:
            # Blank lines are skipped so they do not turn into a bogus
            # "<Annotations>/.xml" entry.
            self.xml_list = [
                os.path.join(self.annotations_root, line.strip() + ".xml")
                for line in read.readlines()
                if line.strip()
            ]

        # read class_indict
        try:
            # The context manager closes the handle even if parsing fails.
            with open('./pascal_voc_classes.json', 'r') as json_file:
                self.class_dict = json.load(json_file)
        except (OSError, ValueError) as e:
            # OSError: file missing/unreadable; ValueError: invalid JSON or
            # undecodable bytes.
            print(e)
            raise SystemExit(-1)

        self.transforms = transforms

    def __len__(self):
        """Return the number of annotation files in the split."""
        return len(self.xml_list)

    def __getitem__(self, idx):
        """Load the image and target for sample ``idx``.

        Returns:
            ``(image, target)`` after ``self.transforms`` has been applied
            (when it is not ``None``).

        Raises:
            ValueError: if the image file is not in JPEG format.
        """
        data = self._read_annotation(idx)
        img_path = os.path.join(self.img_root, data["filename"])
        image = Image.open(img_path)
        if image.format != "JPEG":
            raise ValueError("Image format not JPEG")

        target = self._build_target(data, idx)

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def get_height_and_width(self, idx):
        """Return ``(height, width)`` of sample ``idx`` as recorded in its XML."""
        data = self._read_annotation(idx)
        data_height = int(data["size"]["height"])
        data_width = int(data["size"]["width"])
        return data_height, data_width

    def parse_xml_to_dict(self, xml):
        """Recursively convert an XML tree into nested dictionaries.

        Modelled on TensorFlow's ``recursive_parse_xml_to_dict``.

        Args:
            xml: xml tree obtained by parsing XML file contents using
                lxml.etree

        Returns:
            Python dictionary holding XML contents.
        """
        # len() of an element is its number of children; a leaf is returned
        # directly as {tag: text}.
        if len(xml) == 0:
            return {xml.tag: xml.text}

        result = {}
        for child in xml:
            child_result = self.parse_xml_to_dict(child)  # recurse into the child
            if child.tag != 'object':
                result[child.tag] = child_result[child.tag]
            else:
                # There may be several <object> elements, so they are
                # collected in a list instead of overwriting each other.
                if child.tag not in result:
                    result[child.tag] = []
                result[child.tag].append(child_result[child.tag])
        return {xml.tag: result}

    def coco_index(self, idx):
        """Return the image size and target of sample ``idx`` without the image.

        Intended for gathering label statistics with pycocotools: neither the
        image nor the labels are transformed, and because the image file is
        never opened the statistics pass is much faster.

        Args:
            idx: index of the sample to look up.

        Returns:
            ``((height, width), target)``.
        """
        data = self._read_annotation(idx)
        data_height = int(data["size"]["height"])
        data_width = int(data["size"]["width"])
        target = self._build_target(data, idx)
        return (data_height, data_width), target

    def _read_annotation(self, idx):
        """Parse the XML file of sample ``idx`` and return its ``annotation`` dict."""
        xml_path = self.xml_list[idx]
        with open(xml_path) as fid:
            xml_str = fid.read()
        xml = etree.fromstring(xml_str)
        return self.parse_xml_to_dict(xml)["annotation"]

    def _build_target(self, data, idx):
        """Convert a parsed annotation into the dict of target tensors."""
        boxes = []
        labels = []
        iscrowd = []
        # An annotation without any <object> has no "object" key at all.
        for obj in data.get("object", []):
            bndbox = obj["bndbox"]
            xmin = float(bndbox["xmin"])
            xmax = float(bndbox["xmax"])
            ymin = float(bndbox["ymin"])
            ymax = float(bndbox["ymax"])
            boxes.append([xmin, ymin, xmax, ymax])
            labels.append(self.class_dict[obj["name"]])
            # "iscrowd" carries the "difficult" flag; treat a missing flag
            # as 0.
            iscrowd.append(int(obj.get("difficult", 0)))

        # convert everything into a torch.Tensor
        # reshape keeps boxes two-dimensional, (N, 4), even when N == 0 so
        # the column indexing below stays valid.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        return {
            "boxes": boxes,
            "labels": labels,
            "image_id": image_id,
            "area": area,
            "iscrowd": iscrowd,
        }

    @staticmethod
    def collate_fn(batch):
        """Regroup a list of samples into one tuple per sample field."""
        return tuple(zip(*batch))

# import transforms
# from draw_box_utils import draw_box
# from PIL import Image
# import json
# import matplotlib.pyplot as plt
# import torchvision.transforms as ts
# import random
#
# # read class_indict
# category_index = {}
# try:
#     json_file = open('./pascal_voc_classes.json', 'r')
#     class_dict = json.load(json_file)
#     category_index = {v: k for k, v in class_dict.items()}
# except Exception as e:
#     print(e)
#     exit(-1)
#
# data_transform = {
#     "train": transforms.Compose([transforms.ToTensor(),
#                                  transforms.RandomHorizontalFlip(0.5)]),
#     "val": transforms.Compose([transforms.ToTensor()])
# }
#
# # load train data set
# train_data_set = VOC2012DataSet(os.getcwd(), data_transform["train"], True)
# print(len(train_data_set))
# for index in random.sample(range(0, len(train_data_set)), k=5):
#     img, target = train_data_set[index]
#     img = ts.ToPILImage()(img)
#     draw_box(img,
#              target["boxes"].numpy(),
#              target["labels"].numpy(),
#              [1 for i in range(len(target["labels"].numpy()))],
#              category_index,
#              thresh=0.5,
#              line_thickness=5)
#     plt.imshow(img)
#     plt.show()

transforms.py

class Compose(object):
    """Chain several transforms into a single callable.

    Each transform is called with ``(image, target)`` and must return a pair
    that is fed to the next transform in the sequence.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        """Run every transform in order and return the final ``(image, target)``."""
        result = (image, target)
        for step in self.transforms:
            # Unpack and re-pack so the value handed on (and returned) is
            # always a 2-tuple, whatever pair type the step produced.
            current_image, current_target = step(*result)
            result = (current_image, current_target)
        return result


class ToTensor(object):
    """Convert a PIL image to a Tensor; the target is passed through as-is."""

    def __call__(self, image, target):
        tensor_image = F.to_tensor(image)
        return tensor_image, target


class RandomHorizontalFlip(object):
    """Randomly flip an image tensor and its bounding boxes horizontally.

    Args:
        prob: the flip is applied when ``random.random() < prob``.
    """

    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, image, target):
        """Return ``(image, target)``, flipped left-to-right with chance ``prob``.

        ``image`` is indexed as a tensor whose last dimension is the width.
        ``target["boxes"]`` holds rows of ``xmin, ymin, xmax, ymax``; when a
        flip happens it is replaced by a new tensor, so the tensor passed in
        by the caller is left untouched.
        """
        if random.random() < self.prob:
            width = image.shape[-1]
            image = image.flip(-1)  # flip the image horizontally
            # Work on a copy: writing into the caller's tensor would corrupt
            # annotations that are cached or shared elsewhere.
            bbox = target["boxes"].clone()
            # bbox: xmin, ymin, xmax, ymax
            # Mirroring turns the old xmax into the new xmin and vice versa.
            bbox[:, [0, 2]] = width - bbox[:, [2, 0]]
            target["boxes"] = bbox
        return image, target