pytorch—ImageFolder/自定義類 讀取圖片資料—Transform資料轉換
文章目錄
執行環境安裝 Anaconda | python ==3.6.6
conda install pytorch -c pytorch pip install config pip install tqdm #包裝迭代器,顯示進度條 pip install torchvision pip install scikit-image
一、torchvision 影象資料讀取 [0, 1]
import torchvision.transforms as transforms
transforms 模組提供了一般的影象轉換操作類。
class torchvision.transforms.ToTensor
功能:
把shape=(H x W x C) 的畫素值為 [0, 255] 的 PIL.Image 和 numpy.ndarray
轉換成shape=(C x H x W)的畫素值範圍為[0.0, 1.0]
的 torch.FloatTensor。
class torchvision.transforms.Normalize(mean, std)
功能:
此轉換類作用於torch.*Tensor。給定均值(R, G, B)和標準差(R, G, B),用公式channel = (channel - mean) / std進行規範化。
import torchvision
import torchvision.transforms as transforms
import cv2
import numpy as np
from PIL import Image
img_path = "./data/timg.jpg"
# 引入transforms.ToTensor()功能: range [0, 255] -> [0.0,1.0]
transform1 = transforms.Compose([transforms.ToTensor()])
# 直接讀取:numpy.ndarray
img = cv2.imread(img_path)
print("img = ", img[0]) #只輸出其中一個通道
print("img.shape = ", img.shape)
# 歸一化,轉化為numpy.ndarray並顯示
img1 = transform1(img)
img2 = img1.numpy()*255
img2 = img2.astype('uint8')
img2 = np.transpose(img2 , (1,2,0))
print("img1 = ", img1)
cv2.imshow('img2 ', img2 )
cv2.waitKey()
# PIL 讀取影象
img = Image.open(img_path).convert('RGB') # 讀取影象
img2 = transform1(img) # 歸一化到 [0.0,1.0]
print("img2 = ",img2) #轉化為PILImage並顯示
img_2 = transforms.ToPILImage()(img2).convert('RGB')
print("img_2 = ",img_2)
img_2.show()
從上到下依次輸出:---------------------------------------------
img = [[197 203 202]
[195 203 202]
...
[200 208 207]
[200 208 207]]
img.shape = (362, 434, 3)
img1 = tensor([[[0.7725, 0.7647, 0.7686, ..., 0.7804, 0.7843, 0.7843],
[0.7765, 0.7725, 0.7686, ..., 0.7686, 0.7608, 0.7569],
[0.7843, 0.7725, 0.7686, ..., 0.7725, 0.7686, 0.7569],
...,
img_transform = tensor([[[0.7922, 0.7922, 0.7961, ..., 0.8078, 0.8118, 0.8118],
[0.7961, 0.8000, 0.7961, ..., 0.7922, 0.7882, 0.7843],
[0.8039, 0.8000, 0.7961, ..., 0.8118, 0.8039, 0.7922],
...,
transforms.Compose 歸一化到 [-1.0, 1.0 ]
transform2 = transforms.Compose([transforms.ToTensor()])
transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5))])
二、torchvision 的 Transform
在深度學習時關於影象的資料讀取:由於Tensorflow不支援與numpy的無縫切換,導致難以使用現成的pandas等格式化資料讀取工具,造成了很多不必要的麻煩,而pytorch解決了這個問題。
pytorch自定義讀取資料和進行Transform的部分請見文件:
http://pytorch.org/tutorials/beginner/data_loading_tutorial.html
但是按照文件中所描述所完成的自定義Dataset只能夠使用自定義的Transform步驟,而torchvision包中已經給我們提供了很多影象transform步驟的實現,為了使用這些已經實現的Transform步驟,我們可以使用如下方法定義Dataset:
from __future__ import print_function, division
import os
import torch
import pandas as pd
from PIL import Image
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
class FaceLandmarkDataset(Dataset):
def __len__(self) -> int:
return len(self.landmarks_frame)
def __init__(self, csv_file: str, root_dir: str, transform=None) -> None:
super().__init__()
self.landmarks_frame = pd.read_csv(csv_file)
self.root_dir = root_dir
self.transform = transform
def __getitem__(self, index:int):
img_name = self.landmarks_frame.ix[index, 0]
img_path = os.path.join('./faces', img_name)
with Image.open(img_path) as img:
image = img.convert('RGB')
landmarks = self.landmarks_frame.as_matrix()[index, 1:].astype('float')
landmarks = np.reshape(landmarks,newshape=(-1,2))
if self.transform is not None:
image = self.transform(image)
return image, landmarks
########################以上為資料讀取類(返回:image,landmarks)###############################
trans = transforms.Compose(transforms = [transforms.RandomSizedCrop(size=128),
transforms.ToTensor()])
face_dataset = FaceLandmarkDataset(csv_file='faces/face_landmarks.csv',
root_dir='faces', transform= trans)
loader = DataLoader(dataset = face_dataset,
batch_size=4,
shuffle=True,
num_workers=4)
三、讀取影象資料類
3.1 class torchvision.datasets.ImageFolder 預設讀取影象資料方法:
__init__
( 初始化)classes, class_to_idx = find_classes(root)
:得到分類的類別名(classes)和類別名與數字類別的對映關係字典(class_to_idx)
其中 classes (list): List of the class names.
其中 class_to_idx (dict): Dict with items (class_name, class_index).imgs = make_dataset(root, class_to_idx)
得到imags列表。
其中 imgs (list): List of (image path, class_index) tuples
每個值是一個tuple,每個tuple包含兩個元素:影象路徑和標籤
__getitem__
(影象獲取)path, target = self.imgs[index]
獲取影象(路徑,標籤)img = self.loader(path)
資料讀取。img = self.transform(img)
資料、標籤 轉換成 tensortarget = self.target_transform(target)
__len__
( 資料集數量)return len(self.imgs)
class ImageFolder(data.Dataset):
"""預設影象資料目錄結構
root
.
├──dog
| ├──001.png
| ├──002.png
| └──...
└──cat
| ├──001.png
| ├──002.png
| └──...
└──...
"""
def __init__(self, root, transform=None, target_transform=None,
loader=default_loader):
classes, class_to_idx = find_classes(root)
imgs = make_dataset(root, class_to_idx)
if len(imgs) == 0:
raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"
"Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
self.root = root
self.imgs = imgs
self.classes = classes
self.class_to_idx = class_to_idx
self.transform = transform
self.target_transform = target_transform
self.loader = loader
def __getitem__(self, index):
"""
index (int): Index
Returns:tuple: (image, target) where target is class_index of the target class.
"""
path, target = self.imgs[index]
img = self.loader(path)
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self):
return len(self.imgs)
影象獲取 __getitem__
中,self.loader(path) 採用的是default_loader,如下
def pil_loader(path): # 一般採用pil_loader函式。
with open(path, 'rb') as f:
with Image.open(f) as img:
return img.convert('RGB')
def accimage_loader(path):
import accimage
try:
return accimage.Image(path)
except IOError:
# Potentially a decoding problem, fall back to PIL.Image
return pil_loader(path)
def default_loader(path):
from torchvision import get_image_backend
if get_image_backend() == 'accimage':
return accimage_loader(path)
else:
return pil_loader(path)
3.2 自定義資料讀取方法
PyTorch中和資料讀取相關的類都要繼承一個基類:torch.utils.data.Dataset。
故需要改寫其中的 __init__、__len__、__getitem__
等三個方法即可。
__init__()
初始化傳入引數:- img_path 裡面為所有影象資料(包括訓練和測試)
txt_path 裡面有 train.txt和val.txt兩個檔案:txt檔案中每行都是影象路徑,tab鍵,標籤。 - 其中 self.img_name 和 self.img_label 的讀取方式就跟你資料的存放方式有關(需要調整的地方)
- img_path 裡面為所有影象資料(包括訓練和測試)
__getitem__()
依然採用default_loader方法來讀取影象。Transform
中將每張影象都封裝成 Tensor
class customData(Dataset):
def __init__(self, img_path, txt_path, dataset = '',data_transforms=None, loader = default_loader):
with open(txt_path) as input_file:
"""
關於json檔案解析:
https://blog.csdn.net/wsp_1138886114/article/details/83302339
txt檔案解析如下,具體文字解析具體分析,沒有定數
"""
lines = input_file.readlines()
self.img_name = [os.path.join(img_path, line.strip().split('\t')[0]) for line in lines]
self.img_label = [int(line.strip().split('\t')[-1]) for line in lines]
self.data_transforms = data_transforms
self.dataset = dataset
self.loader = loader
def __len__(self):
return len(self.img_name)
def __getitem__(self, item):
img_name = self.img_name[item]
label = self.img_label[item]
img = self.loader(img_name)
if self.data_transforms is not None:
try:
img = self.data_transforms[self.dataset](img)
except:
print("Cannot transform image: {}".format(img_name))
return img, label
#####################以上為影象資料讀取,返回(img, label)#########################
# 保證image_datasets與torchvision.datasets.ImageFolder類返回的資料型別一樣
image_datasets = {x: customData(img_path='/ImagePath',
txt_path=('/TxtFile/' + x + '.txt'),
data_transforms=data_transforms,
dataset=x) for x in ['train', 'val']}
#用torch.utils.data.DataLoader類,將這個batch的影象資料和標籤都分別封裝成Tensor。
dataloders = {x: torch.utils.data.DataLoader(image_datasets[x],
batch_size=batch_size,
shuffle=True) for x in ['train', 'val']}
# 模型儲存
torch.save(model, 'output/resnet_epoch{}.pkl'.format(epoch))
https://pytorch-cn.readthedocs.io/zh/latest/package_references/data/#torchutilsdata
鳴謝
https://www.cnblogs.com/denny402/p/5096001.html
https://blog.csdn.net/VictoriaW/article/details/72822005
https://blog.csdn.net/hao5335156/article/details/80593349
https://blog.csdn.net/u014380165/article/details/78634829