PyTorch Custom Datasets
阿新 • Published: 2021-02-24
## 準備資料
Prepare the [COCO128](https://www.kaggle.com/ultralytics/coco128) dataset, which consists of the first 128 images of [COCO](https://cocodataset.org) train2017. It is organized in the YOLOv5 directory layout:
```bash
$ tree ~/datasets/coco128 -L 2
/home/john/datasets/coco128
├── images
│ └── train2017
│ ├── ...
│ └── 000000000650.jpg
├── labels
│ └── train2017
│ ├── ...
│ └── 000000000650.txt
├── LICENSE
└── README.txt
```
See [Train Custom Data](https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data) for details.
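Each `labels/train2017/*.txt` file stores one object per line in the normalized YOLO format `class x_center y_center width height`, with all coordinates relative to the image size. A quick sketch of parsing such a line (the sample values below are made up for illustration):

```python
import numpy as np

# One hypothetical YOLO label line: class id, then normalized
# x_center, y_center, width, height.
line = "45 0.479492 0.688771 0.955609 0.595500"

fields = np.array(line.split(), dtype=np.float32)
cls, xywh = int(fields[0]), fields[1:]

print(cls)         # class index
print(xywh.shape)  # the four normalized box coordinates
```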
## Defining the Dataset
`torch.utils.data.Dataset` is an abstract class representing a dataset. To define a custom dataset, subclass `Dataset` and override the following methods:
- `__len__`: so that `len(dataset)` returns the size of the dataset.
- `__getitem__`: so that `dataset[i]` returns the `i`-th sample.
See:
- [torch.utils.data.Dataset](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset)
- [torchvision.datasets.vision.VisionDataset](https://github.com/pytorch/vision/blob/master/torchvision/datasets/vision.py)
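Before the full example, here is a minimal sketch of that two-method contract; the class and data are purely illustrative:

```python
import torch
from torch.utils.data import Dataset


class SquaresDataset(Dataset):
    """Toy dataset: sample i is the pair (i, i*i)."""

    def __init__(self, n: int) -> None:
        self.n = n

    def __len__(self) -> int:
        # Enables len(dataset).
        return self.n

    def __getitem__(self, idx: int):
        # Enables dataset[idx].
        return torch.tensor(idx), torch.tensor(idx * idx)


dataset = SquaresDataset(4)
print(len(dataset))  # 4
print(dataset[3])    # (tensor(3), tensor(9))
```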
An example implementation of a YOLOv5 dataset:
```python
import os
from pathlib import Path
from typing import Any, Callable, Optional, Tuple

import numpy as np
import torch
import torchvision
from PIL import Image


class YOLOv5(torchvision.datasets.vision.VisionDataset):

    def __init__(
        self,
        root: str,
        name: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ) -> None:
        super(YOLOv5, self).__init__(root, transforms, transform, target_transform)
        images_dir = Path(root) / 'images' / name
        labels_dir = Path(root) / 'labels' / name
        self.images = [n for n in images_dir.iterdir()]
        # Pair each image with its label file, or None if it has no labels.
        self.labels = []
        for image in self.images:
            base, _ = os.path.splitext(os.path.basename(image))
            label = labels_dir / f'{base}.txt'
            self.labels.append(label if label.exists() else None)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        img = Image.open(self.images[idx]).convert('RGB')

        label_file = self.labels[idx]
        if label_file is not None:  # found
            with open(label_file, 'r') as f:
                labels = [x.split() for x in f.read().strip().splitlines()]
                labels = np.array(labels, dtype=np.float32)
        else:  # missing
            labels = np.zeros((0, 5), dtype=np.float32)

        # Convert normalized (x_center, y_center, w, h) to absolute (x1, y1, x2, y2).
        boxes = []
        classes = []
        for label in labels:
            x, y, w, h = label[1:]
            boxes.append([
                (x - w / 2) * img.width,
                (y - h / 2) * img.height,
                (x + w / 2) * img.width,
                (y + h / 2) * img.height])
            classes.append(label[0])

        target = {}
        target["boxes"] = torch.as_tensor(boxes, dtype=torch.float32)
        target["labels"] = torch.as_tensor(classes, dtype=torch.int64)

        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target

    def __len__(self) -> int:
        return len(self.images)
```
The implementation above subclasses `VisionDataset`. Its `__getitem__` returns:
- image: a PIL Image of size `(H, W)`
- target: a `dict` with the following fields:
  - `boxes` (`FloatTensor[N, 4]`): ground-truth boxes `[x1, y1, x2, y2]`, with `x` in `[0, W]` and `y` in `[0, H]`
  - `labels` (`Int64Tensor[N]`): the class label of each box
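The box conversion inside `__getitem__` maps a normalized center-format box `(x, y, w, h)` to pixel-space corners `[x1, y1, x2, y2]`. A standalone check of that arithmetic, with an arbitrarily chosen image size:

```python
def xywhn_to_xyxy(x, y, w, h, width, height):
    """Normalized center/size box -> absolute corner box, as in __getitem__."""
    return [(x - w / 2) * width, (y - h / 2) * height,
            (x + w / 2) * width, (y + h / 2) * height]


# A centered box covering half the image in each dimension, on a 640x480 image.
box = xywhn_to_xyxy(0.5, 0.5, 0.5, 0.5, 640, 480)
print(box)  # [160.0, 120.0, 480.0, 360.0]
```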
## Reading the Dataset
```python
dataset = YOLOv5(Path.home() / 'datasets/coco128', 'train2017')
print(f'dataset: {len(dataset)}')
print(f'dataset[0]: {dataset[0]}')
```
Output:
```bash
dataset: 128
datas