Pytorch原始碼解讀-torchvision.transforms
torchvision.transforms對於使用Pytorch的人肯定不陌生,我也用了Pytorch但是對transform卻不是真正掌握,圖片的預處理對網路的效能十分重要,所以打算仔細看看pytorch的原始碼。
Transforms組成
Transforms are common image transforms. They can be chained together using
Compose
Transforms是常用的一些影象變換操作,可以用Compose
將這些變換組合在一起
Transfroms由5部分組成
- Transforms on PIL Image
- Transforms on torch.Tensor
- Conversion Transforms
- Generic Transforms
- Functional Transforms
其中,前兩部分Transforms on PIL Image
和Transforms on torch.*Tensor
用的比較多。下面會詳細說明,先說一下compose
compose
Compose
是一個類用來組合所有的變換操作。
class torchvision.transforms.Compose(transforms)
使用方法,應該很常見:
transforms.Compose([ transforms.CenterCrop(10), transforms.ToTensor(), ])
Transforms on PIL Image
函式 | 用途 | 用法 |
---|---|---|
CenterCrop(size) | crops the given PIL image at the center 從圖片中心剪裁一個size大小的圖片 |
CenterCrop(160) |
GrayScale(num) | convert image to grayscale 將圖片變成灰度圖 |
GrayScale(1)或者 GrayScale(3) |
RandomCrop(size) | crop the given PIL image at a random location 在給定的圖片中隨機的剪裁一張size大小的圖。 這個在訓練中經常用到,算作一種資料增強的手段 | RandomCrop(224) |
RandomHorizontaFlip(p) | Horizontally flip the given image randomly with a given probability 按照給定概率隨機的對圖片進行水平映象 這一條也經常用於訓練 |
RandomHorizontaFlip(0.5) |
Resize(size) | Resize the input PIL image to the given size 一般用在驗證集和測試集 |
Resize(224) |
Transforms on torch.Tensor
Tensor資料的處理只有一個函式Normalize(mean,std)
Normalize函式幾乎都會用到,需要注意的是:
只能對Tensor資料進行Normalize,不能對PIL image用,即在用transform.Normalize()
前先使用下面的transform.ToTensor()
用法:
transform.Normalize([127.5,127.5,127.5],[128,128,128])
對指定通道的畫素減去均值除以方差,一般用於圖片的歸一化,均值和方差的取值需要注意。
Conversion Transforms
這部分包括兩個變換:
- class torchvision.transforms.ToPILImage(mode=None)
將Tensor形式變成PIL Image形式,用的不多 - class torchvision.transforms.ToTensor
將PIL Image形式變成Tensor形式,用的很多,只要你用到gpu都會涉及到。
Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
將一個image或者numpy.ndarrya形式的圖片,圖片的取值範圍[0,255]變成一個tensor,tensor的取值為[0,1],並且通道順序由HWC變成CHW。
需要注意兩點,圖片的取值範圍和通道順序。取值範圍會影響到接下來的歸一化normalize
。通道順序從高x寬x通道變成通道x高x寬,簡單說就是從224x224x3變成3x224x224