1. 程式人生 > >PyTorch官方教程(二)-DataLoadingAndProcessing

PyTorch官方教程(二)-DataLoadingAndProcessing

對於一個新的機器/深度學習任務, 大量的時間都會花費在資料準備上. PyTorch提供了多種輔助工具來幫助使用者更方便的處理和載入資料. 本示例主要會用到以下兩個包:

  • scikit-image: 用於讀取和處理圖片
  • pandas: 用於解析csv檔案

匯入下面的包

from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import
matplotlib.pyplot as plt from torch.utils.data import Dataset, DataLoader from torchvision import transforms, utils # Ignore warnings import warnings warnings.filterwarnings("ignore") plt.ion() # interactive mode

本示例使用的是人臉姿態的資料集, 資料集的標註資訊是由68個landmark點組成的, csv檔案的格式如下所示:

image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, ... ,part_67_x,part_67_y
0805personali01.jpg,27,83,27,98, ... 84,134
1084239450_e76e00b7e7.jpg,70,236,71,257, ... ,128,312

利用如下程式碼可以快速的讀取CSV檔案裡面的標註資訊, 並且將其轉換成 (N,2) 的陣列形式, 其中, N 為 landmarks 點的個數

landmarks_frame = pd.read_csv('faces/face_landmarks.csv')

n = 65
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:].as_matrix()
landmarks = landmarks.astype('float').reshape(-1, 2)

print('Image name: {}'
.format(img_name)) print('Landmarks shape: {}'.format(landmarks.shape)) print('First 4 Landmarks: {}'.format(landmarks[:4]))

利用下面的函式可以將影象和標註檔案中的點顯示出來, 方便觀察:

def show_landmarks(image, landmarks):
    """Show image with landmarks"""
    plt.imshow(image)
    plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')
    plt.pause(0.001)  # pause a bit so that plots are updated

plt.figure()
show_landmarks(io.imread(os.path.join('faces/', img_name)),
               landmarks)
plt.show()

Dataset class

torch.utils.data.Dataset實際上是一個用來表示資料集的虛類, 我們可以通過整合該類來定義我們自己的資料集, 在繼承時, 需要重寫以下方法:

  • __len__: 讓自定義資料集支援通過len(dataset)來返回dataset的size
  • __getitem__: 讓自定義資料集支援通過下標dataset[i]來獲取第 i i 個數據樣本.

接下來, 嘗試建立人臉姿態的自定義資料集. 我們將會在__init__函式中讀取csv檔案, 但是會將讀取圖片的邏輯程式碼寫在__getitem__方法中. 這麼做有助於提高記憶體使用效率, 因為我們並不需要所有的圖片同時儲存在記憶體中, 只需要在用到的時候將指定數量的圖片載入到記憶體中即可.

我們的資料集樣本將會是字典形式: {'image': image, 'landmarks':landmarks}. 我們的資料集將會接受一個可選引數transform, 以便可以將任何需要的圖片處理操作應用在資料樣本上. 使用transform會使得程式碼看起來異常整潔乾淨.

class FaceLandmarksDataset(Dataset):

    def __init__(self, csv_file, root_dir, transform=None):
        # 引數:
        # csv_file(string): csv標籤檔案的路徑
        # root_dir(string): 所有圖片的資料夾路徑
        # transform(callable, optioinal): 可選的變換操作
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        landmarks = self.landmarks.astype("float").reshape(-1,2)
        sample = {"image": image, "landmarks":landmarks}

        if self.transform:
            sample = self.transform(sample)

        return sample

接下來, 讓我們對這個類進行初始化

face_dataset = FaceLandmarksDataset(csv_file="faces/face_.csv", root_dir="faces/")
fig = plt.figure()

for i in range(len(face_dataset)):
    sample = face_dataset[i]
    print(i, sample["image"].shape, sample["landmarks"])
    ax = plt.subplot(1,4,i+1)
    plt.tight_layout()
    ax.set_title("Sample")
    ax.axis("off")
    show_landmarks(**sample)
    if i==3:
        plt.show()
        break

Transforms

嘗試以下三種常見的轉換操作:

  • Rescale: 改變圖片的尺寸大小
  • RandomCrop: 對圖片進行隨機剪裁(資料增廣技術)
  • ToTensor: 將numpy圖片轉換成tensor資料

我們將會把這些操作寫成可供呼叫的類, 而不僅僅是一個簡單的函式, 這樣做的主要好處是不用每次都傳遞transform的相關引數. 為了實現可呼叫的類, 我們需要實現類的 __call__ 方法, 並且根據需要實現 __init__ 方法. 我們可以像下面這樣使用這些類:

tsfm = Transform(params)
transformed_sample = tsfm(sample)

具體實現如下:

class Rescale(object):

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample["image"], sample["landmarks"]

        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h>w:
                new_h, new_w = self.output_size*h/w, self.out_size
            else:
                new_h, new_w = self.output_size, self.output_size*w/h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)
        img = transform.resize(image, (new_h, new_w))
        landmarks = landmarks*[new_w/w, new_h/h]

        return {"image":img, "landmarks": landmarks}

class RandomCrop(object):
    """Crop randomly the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size

        top = np.random.randint(0, h - new_h)
        left = np.random.randint(0, w - new_w)

        image = image[top: top + new_h,
                      left: left + new_w]

        landmarks = landmarks - [left, top]

        return {'image': image, 'landmarks': landmarks}


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image),
                'landmarks': torch.from_numpy(landmarks)}

Compose transforms

接下來, 需要將定義好的轉換操作應用到具體的樣本上, 我們首先將特定的操作組合在一起, 然後利用torchvision.transforms.Compose方法直接將操作應用到對應的圖片上.

scale = Rescale(256)
crop = RandomCrop(128)
composed = transforms.Compose([Rescale(256),
                               RandomCrop(224)])

# Apply each of the above transforms on sample.
fig = plt.figure()
sample = face_dataset[65]
for i, tsfrm in enumerate([scale, crop, composed]):
    transformed_sample = tsfrm(sample)

    ax = plt.subplot(1, 3, i + 1)
    plt.tight_layout()
    ax.set_title(type(tsfrm).__name__)
    show_landmarks(**transformed_sample)

plt.show()

Iterating through the dataset

總結一下對資料取樣的過程:

  • 從檔案中讀取一張圖片
  • 將transforms應用到圖片上
  • 由於transforms是隨機應用的, 因此起到了一定的增廣效果.

可以利用 for i in range迴圈操作來對整個資料集進行transforms

transformed_dataset = FaceLandmarksDataset(csv_file='faces/face_landmarks.csv',root_dir='faces/',
                    transform=transforms.Compose([Rescale(256),RandomCrop(224),ToTensor()]))

for i in range(len(transformed_dataset)):
    sample = transformed_dataset[i]

    print(i, sample['image'].size(), sample['landmarks'].size())

    if i == 3:
        break

Afterword: torchvision

import torch
from torchvision import transforms, datasets

data_transform = transforms.Compose([
        transforms.RandomSizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])
        ])
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
                    transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
                            batch_size=4, shuffle=True,
                            num_workers=4)