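This post introduces torchvision.transforms. The torchvision.transforms package contains common data augmentation operations such as resize and crop, and essentially all of the standard data augmentation operations in PyTorch can be implemented through this interface. The package consists mainly of two scripts, transforms.py and functional.py. The former defines a class for each data augmentation operation, and each class does its work by calling the corresponding function in functional.py.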


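import torch
import torchvision

# Typical usage: chain transforms with Compose and hand the pipeline to a Dataset.
# Normalize works on tensors, so ToTensor must run before it.
train_augmentation = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

class custom_dataread(torch.utils.data.Dataset):
    def __init__(self, transform=None):
        self.transform = transform  # ... also collect image paths and labels here
    def __getitem__(self, index):
        ...  # load the input image, apply self.transform to it, return (image, label)
    def __len__(self):
        ...  # return the number of samples

# batch_size and workers are assumed to be defined elsewhere
train_loader = torch.utils.data.DataLoader(
    custom_dataread(transform=train_augmentation),
    batch_size=batch_size, shuffle=True,
    num_workers=workers, pin_memory=True)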
The main code is in transforms.py. Only the common data augmentation operations are covered here; the source code is as follows:
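First come the necessary imports. The most important one is from . import functional as F, which brings in the concrete data augmentation functions defined in functional.py. The __all__ list names the classes that are exported when the module is imported with from ... import *, i.e. its public interface.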

from __future__ import division
import torch
import math
import random
from PIL import Image, ImageOps, ImageEnhance
try:
    import accimage
except ImportError:
    accimage = None
import numpy as np
import numbers
import types
import collections
import warnings

from . import functional as F

__all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize",
"Scale", "CenterCrop", "Pad", "Lambda", "RandomCrop", 
"RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", 
"RandomSizedCrop", "FiveCrop", "TenCrop","LinearTransformation", 
"ColorJitter", "RandomRotation", "Grayscale", "RandomGrayscale"]
class Compose(object):
    """Composes several transforms together.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:
            img = t(img)
        return img

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            format_string += '\n'
            format_string += '    {0}'.format(t)
        format_string += '\n)'
        return format_string
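Compose does nothing more than call each transform on the output of the previous one. The sketch below reproduces that loop with plain Python functions standing in for transforms; `MiniCompose`, `add_one` and `double` are hypothetical names used only for illustration, not part of torchvision.

```python
class MiniCompose:
    """Hypothetical stand-in for transforms.Compose: apply callables in order."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        # Each transform receives the previous transform's output
        for t in self.transforms:
            img = t(img)
        return img


def add_one(x):
    return x + 1


def double(x):
    return x * 2


# Order matters: (3 + 1) * 2 = 8, whereas (3 * 2) + 1 = 7
print(MiniCompose([add_one, double])(3))  # 8
print(MiniCompose([double, add_one])(3))  # 7
```

This is also why the order of the list passed to Compose matters in practice: each step only sees what the step before it returned.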
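ToTensor implements "Convert a PIL Image or numpy.ndarray to tensor". PyTorch code commonly reads images with the PIL library, so this class acts as the bridge between PIL Images and Tensors. Note that a PIL Image must be converted to a Tensor before normalization (Normalize only accepts tensors), whereas resize and crop operations do not need this because they work on PIL Images directly.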

class ToTensor(object):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        return F.to_tensor(pic)

    def __repr__(self):
        return self.__class__.__name__ + '()'
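For an 8-bit image, ToTensor essentially does two things: it moves the channel axis from last (H x W x C) to first (C x H x W) and divides by 255. The pure-Python sketch below shows only that layout change on nested lists; `hwc_to_chw_scaled` is a hypothetical helper, while the real F.to_tensor also handles other PIL modes and returns a torch.FloatTensor.

```python
def hwc_to_chw_scaled(pixels):
    """Convert an H x W x C nested list of 0-255 ints to C x H x W floats in [0, 1].

    Hypothetical helper that mimics the layout change ToTensor applies to uint8 images.
    """
    height = len(pixels)
    width = len(pixels[0])
    channels = len(pixels[0][0])
    return [
        [[pixels[y][x][c] / 255.0 for x in range(width)] for y in range(height)]
        for c in range(channels)
    ]


# A 1 x 2 RGB image: one pure red pixel and one white pixel
img = [[(255, 0, 0), (255, 255, 255)]]
print(hwc_to_chw_scaled(img))  # [[[1.0, 1.0]], [[0.0, 1.0]], [[0.0, 1.0]]]
```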
ToPILImage, as the name suggests, converts a Tensor (or ndarray) to a PIL Image: it is the inverse of ToTensor.

class ToPILImage(object):
    """Convert a tensor or an ndarray to PIL Image.

    Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
    H x W x C to a PIL Image while preserving the value range.

    Args:
        mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
            If ``mode`` is ``None`` (default) there are some assumptions made about the input data:
            1. If the input has 3 channels, the ``mode`` is assumed to be ``RGB``.
            2. If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``.
            3. If the input has 1 channel, the ``mode`` is determined by the data type (i.e.
            ``int``, ``float``, ``short``).

    .. _PIL.Image mode: http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#modes
    """

    def __init__(self, mode=None):
        self.mode = mode

    def __call__(self, pic):
        """
        Args:
            pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.

        Returns:
            PIL Image: Image converted to PIL Image.
        """
        return F.to_pil_image(pic, self.mode)

    def __repr__(self):
        return self.__class__.__name__ + '({0})'.format(self.mode)

class Normalize(object):
    """Normalize a tensor image with mean and standard deviation.

    Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
    will normalize each channel of the input ``torch.*Tensor`` i.e.
    ``input[channel] = (input[channel] - mean[channel]) / std[channel]``

    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.

        Returns:
            Tensor: Normalized Tensor image.
        """
        return F.normalize(tensor, self.mean, self.std)

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
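Normalize applies (x - mean[c]) / std[c] independently to every value of channel c. The sketch below spells that out on a C x H x W nested list; `normalize_chw` is a hypothetical helper, and the mean/std values are the ImageNet statistics used in the example at the top of this post.

```python
def normalize_chw(tensor, mean, std):
    """Per-channel (x - mean) / std on a C x H x W nested list (hypothetical helper)."""
    return [
        [[(v - m) / s for v in row] for row in channel]
        for channel, m, s in zip(tensor, mean, std)
    ]


mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# A 1 x 1 image whose pixel equals the mean maps to 0 in every channel
print(normalize_chw([[[0.485]], [[0.456]], [[0.406]]], mean, std))  # [[[0.0]], [[0.0]], [[0.0]]]
# A white pixel (1.0 after ToTensor) in the red channel maps to (1 - 0.485) / 0.229, about 2.249
print(normalize_chw([[[1.0]], [[1.0]], [[1.0]]], mean, std)[0][0][0])
```

The mean and std are expressed on the [0, 1] scale that ToTensor produces, which is another reason ToTensor has to come first.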
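Resize resizes a PIL Image and is needed in almost every pipeline. Its size argument can be an int, in which case the shorter edge of the input image is resized to that value and the longer edge is scaled by the same factor, so the aspect ratio is preserved. If size is a sequence (h, w) of ints, the image is resized directly to (h, w). This is a forced resize, so the aspect ratio usually changes and the image content is stretched or squashed. Note that __call__ delegates to the resize function in functional.py; because the input is a PIL Image, that function mostly calls methods of Image. The functions in functional.py that take Tensors (such as normalize) mostly call Tensor methods instead, and together these functions make up most of functional.py.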

class Resize(object):
    """Resize the input PIL Image to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        # collections.Iterable was removed in Python 3.10; newer code uses collections.abc.Iterable
        assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be scaled.

        Returns:
            PIL Image: Rescaled image.
        """
        return F.resize(img, self.size, self.interpolation)

    def __repr__(self):
        return self.__class__.__name__ + '(size={0})'.format(self.size)
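When size is an int, the output shape depends only on which edge is shorter. The helper below sketches that arithmetic as I understand F.resize computes it in this version (the longer edge is truncated with int()); `resize_output_size` is a hypothetical name, and it returns (h, w) to match the docstring's convention, whereas PIL's img.size is (w, h).

```python
def resize_output_size(w, h, size):
    """Return the (h, w) that Resize(size) would produce for a w x h image (sketch)."""
    if isinstance(size, int):
        if w <= h:
            # Width is the shorter edge: it becomes `size`, height keeps the aspect ratio
            return int(size * h / w), size
        # Height is the shorter edge
        return size, int(size * w / h)
    # A (h, w) sequence is used as-is: a forced resize that may change the aspect ratio
    return tuple(size)


print(resize_output_size(640, 480, 256))         # (256, 341)
print(resize_output_size(480, 640, 256))         # (341, 256)
print(resize_output_size(640, 480, (224, 224)))  # (224, 224)
```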

class CenterCrop(object):
    """Crops the given PIL Image at the center.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            PIL Image: Cropped image.
        """
        return F.center_crop(img, self.size)

    def __repr__(self):
        return self.__class__.__name__ + '(size={0})'.format(self.size)
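RandomCrop is used more often than CenterCrop. The difference is that the crop location is random instead of centered on the input image, so each crop is generally different. The two lines i = random.randint(0, h - th) and j = random.randint(0, w - tw) pick a random row i and column j for the top-left corner of the crop (not its center). Note that __call__ (after optional padding) ends by calling F.crop(img, i, j, h, w). CenterCrop calls F.center_crop(img, self.size) instead, but F.center_crop() only computes the top-left corner of the centered box and then also calls F.crop(img, i, j, h, w) to do the actual crop.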

class RandomCrop(object):
    """Crop the given PIL Image at a random location.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is 0, i.e no padding. If a sequence of length
            4 is provided, it is used to pad left, top, right, bottom borders
            respectively.
    """

    def __init__(self, size, padding=0):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for a random crop.

        Args:
            img (PIL Image): Image to be cropped.
            output_size (tuple): Expected output size of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
        """
        w, h = img.size
        th, tw = output_size
        if w == tw and h == th:
            return 0, 0, h, w

        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            PIL Image: Cropped image.
        """
        if self.padding > 0:
            img = F.pad(img, self.padding)

        i, j, h, w = self.get_params(img, self.size)

        return F.crop(img, i, j, h, w)

    def __repr__(self):
        return self.__class__.__name__ + '(size={0})'.format(self.size)
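Both CenterCrop and RandomCrop reduce to computing (i, j), the row and column of the crop's top-left corner, before calling F.crop. The sketch below puts the two computations side by side; the function names are hypothetical, the centered version follows the int(round(...)) formula I believe F.center_crop uses in this version, and a seeded random.Random can replace the global random module so runs are reproducible.

```python
import random


def center_crop_params(w, h, th, tw):
    """Top-left (i, j) plus size for a centered th x tw crop of a w x h image (sketch)."""
    i = int(round((h - th) / 2.0))
    j = int(round((w - tw) / 2.0))
    return i, j, th, tw


def random_crop_params(w, h, th, tw, rng=random):
    """Same logic as RandomCrop.get_params: a uniformly random top-left corner."""
    if w == tw and h == th:
        return 0, 0, h, w
    i = rng.randint(0, h - th)  # randint is inclusive, so the crop always fits
    j = rng.randint(0, w - tw)
    return i, j, th, tw


print(center_crop_params(100, 80, 40, 40))  # (20, 30, 40, 40)
rng = random.Random(0)
for _ in range(3):
    print(random_crop_params(100, 80, 40, 40, rng))  # depends on the seed

# In this version F.crop(img, i, j, h, w) corresponds to PIL's img.crop((j, i, j + w, i + h))
```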
class RandomHorizontalFlip(object):
    """Horizontally flip the given PIL Image randomly with a probability of 0.5."""

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be flipped.

        Returns:
            PIL Image: Randomly flipped image.
        """
        if random.random() < 0.5:
            return F.hflip(img)
        return img

    def __repr__(self):
        return self.__class__.__name__ + '()'
class RandomVerticalFlip(object):
    """Vertically flip the given PIL Image randomly with a probability of 0.5."""

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be flipped.

        Returns:
            PIL Image: Randomly flipped image.
        """
        if random.random() < 0.5:
            return F.vflip(img)
        return img

    def __repr__(self):
        return self.__class__.__name__ + '()'
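Both flip transforms are a single coin toss followed by F.hflip or F.vflip. On an image stored as a list of rows, a horizontal flip reverses each row and a vertical flip reverses the row order. The sketch below uses that representation; the function names and the `p` argument are assumptions for illustration, since the classes shown above hard-code a probability of 0.5.

```python
import random


def hflip(rows):
    # Mirror left-right: reverse the pixels within every row
    return [row[::-1] for row in rows]


def vflip(rows):
    # Mirror top-bottom: reverse the order of the rows
    return rows[::-1]


def random_flip(rows, flip_fn, p=0.5, rng=random):
    """Apply flip_fn with probability p, like RandomHorizontalFlip/RandomVerticalFlip."""
    if rng.random() < p:
        return flip_fn(rows)
    return rows


image = [[1, 2, 3],
         [4, 5, 6]]
print(hflip(image))  # [[3, 2, 1], [6, 5, 4]]
print(vflip(image))  # [[4, 5, 6], [1, 2, 3]]
print(random_flip(image, hflip, rng=random.Random(0)))  # flipped or not, depending on the seed
```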
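RandomResizedCrop is also widely used, and it is a personal favorite. With both CenterCrop and RandomCrop the crop size is fixed, whereas this class crops a region of random size. It relies on three parameters, size, scale and ratio: it first crops (using scale and ratio) and then resizes the crop to the requested size (using size). The crop's position and dimensions come from get_params, which uses scale and ratio as follows. First, a number is drawn uniformly from the scale range and multiplied by the area of the input image to give the area of the crop. Next, a number is drawn uniformly from the ratio range and used as the crop's width-to-height ratio. These two values determine the crop's width and height (which are then swapped with probability 0.5). The top-left corner of the crop is chosen at random, just as in RandomCrop. If no crop that fits inside the image is found within 10 attempts, get_params falls back to a centered square crop whose side equals the shorter edge of the image.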

class RandomResizedCrop(object):
    """Crop the given PIL Image to random size and aspect ratio.

    A crop of random size (default: of 0.08 to 1.0) of the original size and a random
    aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
    is finally resized to given size.
    This is popularly used to train the Inception networks.

    Args:
        size: expected output size of each edge
        scale: range of size of the origin size cropped
        ratio: range of aspect ratio of the origin aspect ratio cropped
        interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR):
        self.size = (size, size)
        self.interpolation = interpolation
        self.scale = scale
        self.ratio = ratio

    @staticmethod
    def get_params(img, scale, ratio):
        """Get parameters for ``crop`` for a random sized crop.

        Args:
            img (PIL Image): Image to be cropped.
            scale (tuple): range of size of the origin size cropped
            ratio (tuple): range of aspect ratio of the origin aspect ratio cropped

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
                sized crop.
        """
        for attempt in range(10):
            area = img.size[0] * img.size[1]
            target_area = random.uniform(*scale) * area
            aspect_ratio = random.uniform(*ratio)

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if random.random() < 0.5:
                w, h = h, w

            if w <= img.size[0] and h <= img.size[1]:
                i = random.randint(0, img.size[1] - h)
                j = random.randint(0, img.size[0] - w)
                return i, j, h, w

        # Fallback
        w = min(img.size[0], img.size[1])
        i = (img.size[1] - w) // 2
        j = (img.size[0] - w) // 2
        return i, j, w, w

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped and resized.

        Returns:
            PIL Image: Randomly cropped and resized image.
        """
        i, j, h, w = self.get_params(img, self.scale, self.ratio)
        return F.resized_crop(img, i, j, h, w, self.size, self.interpolation)

    def __repr__(self):
        return self.__class__.__name__ + '(size={0})'.format(self.size)
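The sampling logic in RandomResizedCrop.get_params only needs the image size, so it can be run on its own. The sketch below is a standalone version of that logic taking (width, height) instead of a PIL Image, with an optional seeded random.Random for reproducibility; `random_resized_crop_params` is a hypothetical name. The comment before the usage line walks through one draw by hand.

```python
import math
import random


def random_resized_crop_params(width, height, scale=(0.08, 1.0),
                               ratio=(3.0 / 4.0, 4.0 / 3.0), rng=random):
    """Return (i, j, h, w) following RandomResizedCrop.get_params (sketch)."""
    area = width * height
    for _ in range(10):
        # Crop area is a random fraction of the image area ...
        target_area = rng.uniform(*scale) * area
        # ... and its width / height ratio is drawn from `ratio`
        aspect_ratio = rng.uniform(*ratio)
        # w * h == target_area and w / h == aspect_ratio, up to rounding
        w = int(round(math.sqrt(target_area * aspect_ratio)))
        h = int(round(math.sqrt(target_area / aspect_ratio)))
        if rng.random() < 0.5:
            w, h = h, w
        if w <= width and h <= height:
            i = rng.randint(0, height - h)  # random top-left row
            j = rng.randint(0, width - w)   # random top-left column
            return i, j, h, w
    # Fallback after 10 failed attempts: centered square crop of the shorter edge
    w = min(width, height)
    return (height - w) // 2, (width - w) // 2, w, w


# Worked example: for a 500 x 375 image, a draw of scale 0.5 and ratio 1.0 gives
# target_area = 0.5 * 187500 = 93750 and w = h = round(sqrt(93750)) = 306.
print(random_resized_crop_params(500, 375, rng=random.Random(0)))
```

The returned box is then cropped and resized to `size` by F.resized_crop, so the output shape is fixed even though the crop shape varies from call to call.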
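FiveCrop, as the name suggests, crops five images of the given size from one input image: the four corner crops plus one center crop. I have come across this usage in the TSN (Temporal Segment Networks) algorithm.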

class FiveCrop(object):
    """Crop the given PIL Image into four corners and the central crop

    .. Note::
         This transform returns a tuple of images and there may be a mismatch in the number of
         inputs and targets your Dataset returns. See below for an example of how to deal with
         this.

    Args:
         size (sequence or int): Desired output size of the crop. If size is an ``int``
            instead of sequence like (h, w), a square crop of size (size, size) is made.

    Example:
         >>> transform = Compose([
         >>>    FiveCrop(size), # this is a list of PIL Images
         >>>    Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor
         >>> ])
         >>> #In your test loop you can do the following:
         >>> input, target = batch # input is a 5d tensor, target is 2d
         >>> bs, ncrops, c, h, w = input.size()
         >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
         >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
    """

    def __init__(self, size):
        self.size = size
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
            self.size = size

    def __call__(self, img):
        return F.five_crop(img, self.size)

    def __repr__(self):
        return self.__class__.__name__ + '(size={0})'.format(self.size)
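The five crops are fixed boxes, so they can be listed without touching any pixels. The sketch below computes them as PIL-style (left, upper, right, lower) boxes in the order top-left, top-right, bottom-left, bottom-right, center, which is the order I believe F.five_crop returns; `five_crop_boxes` is a hypothetical helper.

```python
def five_crop_boxes(w, h, crop_h, crop_w):
    """PIL-style (left, upper, right, lower) boxes for the five crops of a w x h image."""
    if crop_w > w or crop_h > h:
        raise ValueError("crop size is larger than the image")
    tl = (0, 0, crop_w, crop_h)
    tr = (w - crop_w, 0, w, crop_h)
    bl = (0, h - crop_h, crop_w, h)
    br = (w - crop_w, h - crop_h, w, h)
    # Center crop: same top-left rounding as CenterCrop
    i = int(round((h - crop_h) / 2.0))
    j = int(round((w - crop_w) / 2.0))
    center = (j, i, j + crop_w, i + crop_h)
    return tl, tr, bl, br, center


print(five_crop_boxes(100, 80, 40, 40))
# ((0, 0, 40, 40), (60, 0, 100, 40), (0, 40, 40, 80), (60, 40, 100, 80), (30, 20, 70, 60))
```

TenCrop adds the same five boxes taken from the horizontally (or, with vertical_flip=True, vertically) flipped image, giving ten crops in total. That is why the docstring examples reshape a (bs, ncrops, c, h, w) batch to (bs * ncrops, c, h, w) before the forward pass and average over the crops afterwards.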

class TenCrop(object):
    """Crop the given PIL Image into four corners and the central crop plus the flipped version of
    these (horizontal flipping is used by default)

    .. Note::
         This transform returns a tuple of images and there may be a mismatch in the number of
         inputs and targets your Dataset returns. See below for an example of how to deal with
         this.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        vertical_flip(bool): Use vertical flipping instead of horizontal

    Example:
         >>> transform = Compose([
         >>>    TenCrop(size), # this is a list of PIL Images
         >>>    Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor
         >>> ])
         >>> #In your test loop you can do the following:
         >>> input, target = batch # input is a 5d tensor, target is 2d
         >>> bs, ncrops, c, h, w = input.size()
         >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
         >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
    """

    def __init__(self, size, vertical_flip=False):
        self.size = size
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
            self.size = size
        self.vertical_flip = vertical_flip

    def __call__(self, img):
        return F.ten_crop(img, self.size, self.vertical_flip)

    def __repr__(self):
        return self.__class__.__name__ + '(size={0})'.format(self.size)

class LinearTransformation(object):
    """Transform a tensor image with a square transformation matrix computed
    offline.

    Given transformation_matrix, will flatten the torch.*Tensor, compute the dot
    product with the transformation matrix and reshape the tensor to its
    original shape.

    Applications:
    - whitening: zero-center the data, compute the data covariance matrix
                 [D x D] with np.dot(X.T, X), perform SVD on this matrix and
                 pass it as transformation_matrix.

    Args:
        transformation_matrix (Tensor): tensor [D x D], D = C x H x W
    """

    def __init__(self, transformation_matrix):
        if transformation_matrix.size(0) != transformation_matrix.size(1):
            raise ValueError("transformation_matrix should be square. Got " +
                             "[{} x {}] rectangular matrix.".format(*transformation_matrix.size()))
        self.transformation_matrix = transformation_matrix

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be whitened.

        Returns:
            Tensor: Transformed image.
        """
        if tensor.size(0) * tensor.size(1) * tensor.size(2) != self.transformation_matrix.size(0):
            raise ValueError("tensor and transformation matrix have incompatible shape." +
                             "[{} x {} x {}] != ".format(*tensor.size()) +
                             "{}".format(self.transformation_matrix.size(0)))
        flat_tensor = tensor.view(1, -1)
        transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)
        tensor = transformed_tensor.view(tensor.size())
        return tensor

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        format_string += (str(self.transformation_matrix.numpy().tolist()) + ')')
        return format_string
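LinearTransformation flattens the C x H x W tensor into a 1 x D row vector (D = C * H * W), multiplies it by the D x D matrix, and reshapes the result back. The pure-Python sketch below performs the same row-vector-times-matrix product on a flat list; `apply_linear_transformation` is a hypothetical helper, and computing an actual whitening matrix (the SVD step in the docstring) is left out.

```python
def apply_linear_transformation(flat, matrix):
    """Row vector (1 x D) times a D x D matrix, as torch.mm(flat_tensor, matrix) does."""
    d = len(flat)
    if len(matrix) != d or any(len(row) != d for row in matrix):
        raise ValueError("matrix must be D x D with D == len(flat)")
    # Output element k is the dot product of the row vector with column k
    return [sum(flat[r] * matrix[r][k] for r in range(d)) for k in range(d)]


x = [1.0, 2.0, 3.0, 4.0]  # e.g. a 1 x 2 x 2 image already flattened
identity = [[1.0 if r == k else 0.0 for k in range(4)] for r in range(4)]
swap_first_two = [[0.0, 1.0, 0.0, 0.0],
                  [1.0, 0.0, 0.0, 0.0],
                  [0.0, 0.0, 1.0, 0.0],
                  [0.0, 0.0, 0.0, 1.0]]
print(apply_linear_transformation(x, identity))        # [1.0, 2.0, 3.0, 4.0]
print(apply_linear_transformation(x, swap_first_two))  # [2.0, 1.0, 3.0, 4.0]
```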
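ColorJitter is also fairly common. It randomly changes four properties of the input image: brightness, contrast, saturation and hue. Set these four parameters according to the ranges described in the docstring; the resulting adjustments are applied in a random order.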

class ColorJitter(object):
    """Randomly change the brightness, contrast and saturation of an image.

    Args:
        brightness (float): How much to jitter brightness. brightness_factor
            is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
        contrast (float): How much to jitter contrast. contrast_factor
            is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
        saturation (float): How much to jitter saturation. saturation_factor
            is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
        hue(float): How much to jitter hue. hue_factor is chosen uniformly from
            [-hue, hue]. Should be >=0 and <= 0.5.
    """
    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue

    @staticmethod
    def get_params(brightness, contrast, saturation, hue):
        """Get a randomized transform to be applied on image.

        Arguments are same as that of __init__.

        Returns:
            Transform which randomly adjusts brightness, contrast and
            saturation in a random order.
        """
        transforms = []
        if brightness > 0:
            brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness)
            transforms.append(Lambda(lambda img: F.adjust_brightness(img, brightness_factor)))

        if contrast > 0:
            contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast)
            transforms.append(Lambda(lambda img: F.adjust_contrast(img, contrast_factor)))

        if saturation > 0:
            saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation)
            transforms.append(Lambda(lambda img: F.adjust_saturation(img, saturation_factor)))

        if hue > 0:
            hue_factor = np.random.uniform(-hue, hue)
            transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor)))

        # Apply the adjustments in a random order, as the docstring states
        random.shuffle(transforms)
        transform = Compose(transforms)

        return transform

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Input image.

        Returns:
            PIL Image: Color jittered image.
        """
        transform = self.get_params(self.brightness, self.contrast,
                                    self.saturation, self.hue)
        return transform(img)

    def __repr__(self):
        return self.__class__.__name__ + '()'
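The only randomness in ColorJitter is in the factors: each nonzero argument x yields a factor drawn uniformly from [max(0, 1 - x), 1 + x] (hue from [-x, x]), and the adjustments are then applied in a random order. The sketch below draws the factors and the order the same way; `sample_color_jitter` is a hypothetical helper, and it uses Python's random module in place of np.random.

```python
import random


def sample_color_jitter(brightness=0, contrast=0, saturation=0, hue=0, rng=random):
    """Return a shuffled list of (name, factor) pairs, like ColorJitter.get_params (sketch)."""
    ops = []
    for name, amount in (("brightness", brightness), ("contrast", contrast),
                         ("saturation", saturation)):
        if amount > 0:
            # A factor of 1.0 leaves the image unchanged; the lower bound is clipped at 0
            ops.append((name, rng.uniform(max(0, 1 - amount), 1 + amount)))
    if hue > 0:
        # The docstring requires 0 <= hue <= 0.5, so the shift stays within [-0.5, 0.5]
        ops.append(("hue", rng.uniform(-hue, hue)))
    rng.shuffle(ops)
    return ops


# brightness=0.4 gives a factor in [0.6, 1.4]; hue=0.1 gives a shift in [-0.1, 0.1]
print(sample_color_jitter(brightness=0.4, hue=0.1, rng=random.Random(0)))
```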
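RandomRotation rotates the input image by a random angle and is also fairly common. See the docstring for the meaning of each parameter; F.rotate() mainly calls the rotate method of the PIL Image.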

class RandomRotation(object):
    """Rotate the image by angle.

    Args:
        degrees (sequence or float or int): Range of degrees to select from.
            If degrees is a number instead of sequence like (min, max), the range of degrees
            will be (-degrees, +degrees).
        resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional):
            An optional resampling filter.
            See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters
            If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
        expand (bool, optional): Optional expansion flag.
            If true, expands the output to make it large enough to hold the entire rotated image.
            If false or omitted, make the output image the same size as the input image.
            Note that the expand flag assumes rotation around the center and no translation.
        center (2-tuple, optional): Optional center of rotation.
            Origin is the upper left corner.
            Default is the center of the image.
    """

    def __init__(self, degrees, resample=False, expand=False, center=None):
        if isinstance(degrees, numbers.Number):
            if degrees < 0:
                raise ValueError("If degrees is a single number, it must be positive.")
            self.degrees = (-degrees, degrees)
        else:
            if len(degrees) != 2:
                raise ValueError("If degrees is a sequence, it must be of len 2.")
            self.degrees = degrees

        self.resample = resample
        self.expand = expand
        self.center = center

    @staticmethod
    def get_params(degrees):
        """Get parameters for ``rotate`` for a random rotation.

        Returns:
            sequence: params to be passed to ``rotate`` for random rotation.
        """
        angle = np.random.uniform(degrees[0], degrees[1])

        return angle

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be rotated.

        Returns:
            PIL Image: Rotated image.
        """

        angle = self.get_params(self.degrees)

        return F.rotate(img, angle, self.resample, self.expand, self.center)

    def __repr__(self):
        return self.__class__.__name__ + '(degrees={0})'.format(self.degrees)
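RandomRotation normalizes degrees into a (min, max) range once in __init__ and then draws one angle per call. The sketch below reproduces those two steps; the helper names are hypothetical and random.uniform stands in for np.random.uniform.

```python
import numbers
import random


def degrees_to_range(degrees):
    """Mimic RandomRotation.__init__: a number d becomes (-d, d); a pair is kept."""
    if isinstance(degrees, numbers.Number):
        if degrees < 0:
            raise ValueError("If degrees is a single number, it must be positive.")
        return (-degrees, degrees)
    if len(degrees) != 2:
        raise ValueError("If degrees is a sequence, it must be of len 2.")
    return tuple(degrees)


def sample_angle(degree_range, rng=random):
    # One angle per call, uniform over the range, in degrees
    return rng.uniform(degree_range[0], degree_range[1])


print(degrees_to_range(30))        # (-30, 30)
print(degrees_to_range((10, 20)))  # (10, 20)
print(sample_angle(degrees_to_range(30), random.Random(0)))
```

The sampled angle is passed to F.rotate and from there to PIL's Image.rotate, which treats positive angles as counter-clockwise.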

class Grayscale(object):
    """Convert image to grayscale.

    Args:
        num_output_channels (int): (1 or 3) number of channels desired for output image

    Returns:
        PIL Image: Grayscale version of the input.
        - If num_output_channels == 1 : returned image is single channel
        - If num_output_channels == 3 : returned image is 3 channel with r == g == b

    """

    def __init__(self, num_output_channels=1):
        self.num_output_channels = num_output_channels

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be converted to grayscale.

        Returns:
            PIL Image: Grayscaled image.
        """
        return F.to_grayscale(img, num_output_channels=self.num_output_channels)

    def __repr__(self):
        return self.__class__.__name__ + '()'

class RandomGrayscale(object):
    """Randomly convert image to grayscale with a probability of p (default 0.1).

    Args:
        p (float): probability that image should be converted to grayscale.

    Returns:
        PIL Image: Grayscale version of the input image with probability p and unchanged
        with probability (1-p).
        - If input image is 1 channel: grayscale version is 1 channel
        - If input image is 3 channel: grayscale version is 3 channel with r == g == b

    """

    def __init__(self, p=0.1):
        self.p = p

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be converted to grayscale.

        Returns:
            PIL Image: Randomly grayscaled image.
        """
        num_output_channels = 1 if img.mode == 'L' else 3
        if random.random() < self.p:
            return F.to_grayscale(img, num_output_channels=num_output_channels)
        return img

    def __repr__(self):
        return self.__class__.__name__ + '()'
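As far as I can tell, F.to_grayscale in this version relies on PIL's convert('L'), which per the Pillow documentation uses the ITU-R 601-2 luma transform L = R * 299/1000 + G * 587/1000 + B * 114/1000. The sketch below applies that formula to one pixel and, as with num_output_channels=3, copies the result into three equal channels; `to_grayscale_pixel` is a hypothetical helper, and its rounding may differ slightly from Pillow's integer arithmetic.

```python
def to_grayscale_pixel(rgb, num_output_channels=1):
    """Luma of one (R, G, B) pixel, replicated into 1 or 3 channels (sketch)."""
    r, g, b = rgb
    luma = int(round(r * 299 / 1000 + g * 587 / 1000 + b * 114 / 1000))
    if num_output_channels == 1:
        return (luma,)
    if num_output_channels == 3:
        return (luma, luma, luma)  # r == g == b, as the docstring says
    raise ValueError("num_output_channels should be either 1 or 3")


print(to_grayscale_pixel((255, 255, 255)))  # (255,)
print(to_grayscale_pixel((255, 0, 0), 3))   # (76, 76, 76)
```

RandomGrayscale applies the same conversion with probability p, choosing num_output_channels = 1 for mode 'L' input and 3 otherwise, so the output keeps the channel count of the input.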
