PyTorch官方教程(二)-DataLoadingAndProcessing
對於一個新的機器/深度學習任務, 大量的時間都會花費在資料準備上. PyTorch提供了多種輔助工具來幫助使用者更方便的處理和載入資料. 本示例主要會用到以下兩個包:
- scikit-image: 用於讀取和處理圖片
- pandas: 用於解析csv檔案
匯入下面的包
from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
# Ignore warnings
import warnings
warnings.filterwarnings("ignore")
plt.ion() # interactive mode
本示例使用的是人臉姿態的資料集, 資料集的標註資訊是由68個landmark點組成的, csv檔案的格式如下所示:
image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, ... ,part_67_x,part_67_y 0805personali01.jpg,27,83,27,98, ... 84,134 1084239450_e76e00b7e7.jpg,70,236,71,257, ... ,128,312
利用如下程式碼可以快速的讀取CSV檔案裡面的標註資訊, 並且將其轉換成 (N,2) 的陣列形式, 其中, N 為 landmarks 點的個數
landmarks_frame = pd.read_csv('faces/face_landmarks.csv')
n = 65
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:].as_matrix()
landmarks = landmarks.astype('float').reshape(-1, 2)
print('Image name: {}' .format(img_name))
print('Landmarks shape: {}'.format(landmarks.shape))
print('First 4 Landmarks: {}'.format(landmarks[:4]))
利用下面的函式可以將影象和標註檔案中的點顯示出來, 方便觀察:
def show_landmarks(image, landmarks):
"""Show image with landmarks"""
plt.imshow(image)
plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')
plt.pause(0.001) # pause a bit so that plots are updated
plt.figure()
show_landmarks(io.imread(os.path.join('faces/', img_name)),
landmarks)
plt.show()
Dataset class
torch.utils.data.Dataset
實際上是一個用來表示資料集的虛類, 我們可以通過整合該類來定義我們自己的資料集, 在繼承時, 需要重寫以下方法:
__len__
: 讓自定義資料集支援通過len(dataset)
來返回dataset的size__getitem__
: 讓自定義資料集支援通過下標dataset[i]
來獲取第 個數據樣本.
接下來, 嘗試建立人臉姿態的自定義資料集. 我們將會在__init__
函式中讀取csv檔案, 但是會將讀取圖片的邏輯程式碼寫在__getitem__
方法中. 這麼做有助於提高記憶體使用效率, 因為我們並不需要所有的圖片同時儲存在記憶體中, 只需要在用到的時候將指定數量的圖片載入到記憶體中即可.
我們的資料集樣本將會是字典形式: {'image': image, 'landmarks':landmarks}
. 我們的資料集將會接受一個可選引數transform
, 以便可以將任何需要的圖片處理操作應用在資料樣本上. 使用transform
會使得程式碼看起來異常整潔乾淨.
class FaceLandmarksDataset(Dataset):
def __init__(self, csv_file, root_dir, transform=None):
# 引數:
# csv_file(string): csv標籤檔案的路徑
# root_dir(string): 所有圖片的資料夾路徑
# transform(callable, optioinal): 可選的變換操作
self.landmarks_frame = pd.read_csv(csv_file)
self.root_dir = root_dir
self.transform = transform
def __len__(self):
return len(self.landmarks_frame)
def __getitem__(self, idx):
img_name = os.path.join(self.root_dir, self.landmarks_frame.iloc[idx, 0])
image = io.imread(img_name)
landmarks = self.landmarks.astype("float").reshape(-1,2)
sample = {"image": image, "landmarks":landmarks}
if self.transform:
sample = self.transform(sample)
return sample
接下來, 讓我們對這個類進行初始化
face_dataset = FaceLandmarksDataset(csv_file="faces/face_.csv", root_dir="faces/")
fig = plt.figure()
for i in range(len(face_dataset)):
sample = face_dataset[i]
print(i, sample["image"].shape, sample["landmarks"])
ax = plt.subplot(1,4,i+1)
plt.tight_layout()
ax.set_title("Sample")
ax.axis("off")
show_landmarks(**sample)
if i==3:
plt.show()
break
Transforms
嘗試以下三種常見的轉換操作:
- Rescale: 改變圖片的尺寸大小
- RandomCrop: 對圖片進行隨機剪裁(資料增廣技術)
- ToTensor: 將numpy圖片轉換成tensor資料
我們將會把這些操作寫成可供呼叫的類, 而不僅僅是一個簡單的函式, 這樣做的主要好處是不用每次都傳遞transform的相關引數. 為了實現可呼叫的類, 我們需要實現類的 __call__
方法, 並且根據需要實現 __init__
方法. 我們可以像下面這樣使用這些類:
tsfm = Transform(params)
transformed_sample = tsfm(sample)
具體實現如下:
class Rescale(object):
def __init__(self, output_size):
assert isinstance(output_size, (int, tuple))
self.output_size = output_size
def __call__(self, sample):
image, landmarks = sample["image"], sample["landmarks"]
h, w = image.shape[:2]
if isinstance(self.output_size, int):
if h>w:
new_h, new_w = self.output_size*h/w, self.out_size
else:
new_h, new_w = self.output_size, self.output_size*w/h
else:
new_h, new_w = self.output_size
new_h, new_w = int(new_h), int(new_w)
img = transform.resize(image, (new_h, new_w))
landmarks = landmarks*[new_w/w, new_h/h]
return {"image":img, "landmarks": landmarks}
class RandomCrop(object):
"""Crop randomly the image in a sample.
Args:
output_size (tuple or int): Desired output size. If int, square crop
is made.
"""
def __init__(self, output_size):
assert isinstance(output_size, (int, tuple))
if isinstance(output_size, int):
self.output_size = (output_size, output_size)
else:
assert len(output_size) == 2
self.output_size = output_size
def __call__(self, sample):
image, landmarks = sample['image'], sample['landmarks']
h, w = image.shape[:2]
new_h, new_w = self.output_size
top = np.random.randint(0, h - new_h)
left = np.random.randint(0, w - new_w)
image = image[top: top + new_h,
left: left + new_w]
landmarks = landmarks - [left, top]
return {'image': image, 'landmarks': landmarks}
class ToTensor(object):
"""Convert ndarrays in sample to Tensors."""
def __call__(self, sample):
image, landmarks = sample['image'], sample['landmarks']
# swap color axis because
# numpy image: H x W x C
# torch image: C X H X W
image = image.transpose((2, 0, 1))
return {'image': torch.from_numpy(image),
'landmarks': torch.from_numpy(landmarks)}
Compose transforms
接下來, 需要將定義好的轉換操作應用到具體的樣本上, 我們首先將特定的操作組合在一起, 然後利用torchvision.transforms.Compose
方法直接將操作應用到對應的圖片上.
scale = Rescale(256)
crop = RandomCrop(128)
composed = transforms.Compose([Rescale(256),
RandomCrop(224)])
# Apply each of the above transforms on sample.
fig = plt.figure()
sample = face_dataset[65]
for i, tsfrm in enumerate([scale, crop, composed]):
transformed_sample = tsfrm(sample)
ax = plt.subplot(1, 3, i + 1)
plt.tight_layout()
ax.set_title(type(tsfrm).__name__)
show_landmarks(**transformed_sample)
plt.show()
Iterating through the dataset
總結一下對資料取樣的過程:
- 從檔案中讀取一張圖片
- 將transforms應用到圖片上
- 由於transforms是隨機應用的, 因此起到了一定的增廣效果.
可以利用 for i in range
迴圈操作來對整個資料集進行transforms
transformed_dataset = FaceLandmarksDataset(csv_file='faces/face_landmarks.csv',root_dir='faces/',
transform=transforms.Compose([Rescale(256),RandomCrop(224),ToTensor()]))
for i in range(len(transformed_dataset)):
sample = transformed_dataset[i]
print(i, sample['image'].size(), sample['landmarks'].size())
if i == 3:
break
Afterword: torchvision
import torch
from torchvision import transforms, datasets
data_transform = transforms.Compose([
transforms.RandomSizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
batch_size=4, shuffle=True,
num_workers=4)