1. 程式人生 > 其它 >CV(一)之自定義資料集

CV(一)之自定義資料集

本文以 PASCAL VOC2012 資料集為例子進行說明。(下載地址:PASCAL VOC2012)

Pytorch 自定義資料集見文件:TorchVision Object Detection Finetuning Tutorial

本文將以PASCAL VOC為基礎自定義一個數據集VOCDataset,並隨機選取五張圖片給將其對應的標註轉化為矩形框畫在圖片上。

生成自定義資料集

一些需要匯入的基本庫

import os
import torch
import json
from torch.utils.data import Dataset
from PIL import Image
from os import path
from lxml import etree

# 類別資料
class_dict = {
    "aeroplane": 1,
    "bicycle": 2,
    "bird": 3,
    "boat": 4,
    "bottle": 5,
    "bus": 6,
    "car": 7,
    "cat": 8,
    "chair": 9,
    "cow": 10,
    "diningtable": 11,
    "dog": 12,
    "horse": 13,
    "motorbike": 14,
    "person": 15,
    "pottedplant": 16,
    "sheep": 17,
    "sofa": 18,
    "train": 19,
    "tvmonitor": 20
}

按照文件要求,在VOCDataset中實現三個方法__len____getitem__、以及get_height_and_width

初始化 VOCDataset 類

建構函式定義如下

'''
voc_root: voc 資料集的根目錄
year: 哪一個年份的資料集
transforms: 資料預處理
text_name: train.txt or val.txt 該txt檔案在資料集的 VOCdevkit\VOC2012\ImageSets\Main 資料夾下
'''
def __init__(self, voc_root, year='2012', transforms=None, text_name='train.txt'):

在建構函式中,我們主要完成以下三個功能

  1. 設定圖片路徑image_root和標註路徑anno_root
  2. 設定此次要訓練的樣本所有標註檔案路徑列表xml_list
  3. 設定要檢測的目標類別資訊class_dict
設定圖片路徑image_root和標註路徑anno_root
        # 設定資料集、圖片、標註的根目錄
        self.root = path.join(voc_root, 'VOCdevkit', f'VOC{year}')
        self.image_root = path.join(self.root, 'JPEGImages')
        self.anno_root = path.join(self.root, 'Annotations')
設定此次要訓練的樣本所有標註檔案路徑列表xml_list
        # 根據 text_name 拿到對應的標註xml檔案路徑
        text_path = path.join(self.root, 'ImageSets','Main', text_name)
        # 讀取txt檔案的每一行並生成xml標註檔案路徑存放在xml_list中
        with open(text_path) as file_reader:
            self.xml_list = [
                path.join(self.anno_root, f'{line.strip()}.xml')
                for line in file_reader.readlines() if len(line.strip()) > 0
            ]
設定要檢測的目標類別資訊class_dict
        self.class_dict = class_dict

一般使用 0 來表示當前類別是背景

獲取所有樣例條數

    def __len__(self):
        return len(self.xml_list)

樣本的條數即標註檔案列表長度

根據索引獲取指定樣本

函式定義如下

    def __getitem__(self, idx):

傳入的即為樣本的索引值,其取值範圍為 0 ~ len(xml_list)

獲取指定樣本需要分為如下兩大步

  1. 獲取圖片
  2. 獲取圖片資訊(標註資訊、索引、區域面積等)
獲取圖片

首先我們需要根據索引拿到對應標註資訊,並將其轉化為json格式
定義一個獲取json格式的annotation的方法

    def get_annotation(self, idx):
        xml_path = self.xml_list[idx]
        assert path.exists(xml_path), f'file {xml_path} not found'

        xml_reader = open(xml_path)
        xml_text = xml_reader.read()
        xml = etree.fromstring(xml_text)
        annotation = parse_xml_to_dict(xml)['annotation']

xml格式轉化為json格式函式如下

def parse_xml_to_dict(xml):
    if len(xml) == 0:
        return {xml.tag: xml.text}
    
    result = {}
    for child in xml:
        child_result = parse_xml_to_dict(child)
        if child.tag != 'object':
            result[child.tag] = child_result[child.tag]
        else: # 一張圖片中可能標註有多個 object
            if child.tag not in result:
                result[child.tag] = []
            result[child.tag].append(child_result[child.tag])
    
    return {xml.tag: result}

獲取annotation

        annotation = self.get_annotation(idx)

然後我們就可以從annotation中拿到檔名稱並獲取到檔案

        image_path = path.join(self.image_root, annotation['filename'])
        image = Image.open(image_path)
獲取圖片資訊

宣告需要獲取的所有資訊

        # 生成 target
        target = {
            'boxes': [], # 標註的左上、右下座標(xmin, ymin, xmax, ymax)
            'labels': [],# 標註類別
            'image_id': [], # 圖片索引
            'area': [], # 含有目標區域的面積 (xmax-xmin) * (ymax-ymin)
            'iscrowd': [], # 是不是一堆密集的東西在一起
        }

便利所有的object


        for obj in annotation['object']:
            bndbox = obj['bndbox']
            xmin = float(bndbox['xmin'])
            ymin = float(bndbox['ymin'])
            xmax = float(bndbox['xmax'])
            ymax = float(bndbox['ymax'])
            target['boxes'].append([xmin, ymin, xmax, ymax]) # 設定有目標的座標資訊
            target['labels'].append(self.class_dict[obj['name']]) # 獲取對應的標籤
            target['area'].append((xmax - xmin) * (ymax - ymin)) # 計算面積

            # 使用 difficult(當前目標是否難以識別) 欄位來設定 iscrowd
            if 'difficult' in obj:
                target['iscrowd'].append(int(obj['difficult']))
            else:
                target['iscrowd'].append(0)

將所有資訊轉化為Tensor

        # Convert to tensor
        target['boxes'] = torch.as_tensor(target['boxes'])
        target['labels'] = torch.as_tensor(target['labels'])
        target['iscrowd'] = torch.as_tensor(target['iscrowd'])
        target['area'] = torch.as_tensor(target['area'])
        target['image_id'] = torch.tensor([idx])

如果有設定資料前處理器,則在返回資料前呼叫

        if self.transforms is not None:
            image = self.transforms(image)

返回圖片以及對應的資訊

        return image, target

根據索引獲取當前圖片的寬高

在標註資訊裡面含有圖片寬高資訊,所以可以很容易獲取到

    def get_height_and_width(self, idx):
        annotation = annotation = self.get_annotation(idx)
        # 從 annotation 中取出寬高並返回
        width = int(annotation['size']['width'])
        height = int(annotation['size']['height'])

        return height, width

以上我們就完成了資料集的定義,下面我們將使用例項程式碼來使用這個資料集

使用自定義資料集並畫上標註框

匯入一些基本庫

import random
import matplotlib.pyplot as plt
import torchvision.transforms as ts
from draw_box_utils import draw_box

生成類別資料,將 kv 互換,便於查詢

category_index = {}

category_index = {
    v: k
    for k, v in class_dict.items()
}

定義transformer,將資料轉化為Tensor

data_transform = ts.Compose([ts.ToTensor()])

由於ToTensor會將資料標準化,為了程式碼簡潔,這裡不使用

拿到資料集並將目標框以及類別畫出來

train_data_set = VOCDataset(os.getcwd(), '2012', None, 'train.txt')

for index in random.sample(range(0, len(train_data_set)), k=5):
    image, target = train_data_set[index]
    image = draw_bounding_boxes(
        np.array(image),
        target['boxes'],
        target['labels'],
    )
    plt.imshow(image)
    plt.show()

畫目標框draw_bounding_boxes程式碼如下(參考程式碼: vision/utils.py at main · pytorch/vision (github.com))


def draw_bounding_boxes(
    image,
    boxes: torch.Tensor,
    labels: Optional[List[str]] = None
) -> torch.Tensor:
    img_to_draw = Image.fromarray(image)
    img_boxes = boxes.to(torch.int64).tolist()
    draw = ImageDraw.Draw(img_to_draw)

    for i, bbox in enumerate(img_boxes):
        draw.rectangle(bbox, width=2, outline='red')
        margin = 2
        draw.text((bbox[0] + margin, bbox[1] + margin),  category_index[labels[i] - 1], fill='red')


    return np.array(img_to_draw)

這樣就完成了整個流程了!

執行與測試

可見執行結果正確!