Pytorch資料讀取機制(DataLoader)與影象預處理模組(transforms)
阿新 • • 發佈:2021-07-20
Pytorch資料讀取機制(DataLoader)與影象預處理模組(transforms)
1.DataLoader
torch.utils.data.DataLoader()
:構建可迭代的資料裝載器, 訓練的時候,每一個for迴圈,每一次iteration,就是從DataLoader中獲取一個batch_size大小的資料的。
Dataloader()引數:
- dataset: Dataset類,決定資料從哪讀取(資料路徑)以及如何讀取(做哪些預處理)
- batchsize: 批大小
- num_works: 是否採用多程序讀取機制
- shuffle: 每一個epoch是否亂序
- drop_last: 當樣本數不能被batchsize整除時,是否捨棄最後一批資料。
2. Dataset
torch.utils.data.Dataset()
:Dataset抽象類, 所有自定義的Dataset都需要繼承它,並且必須複寫__getitem__()
這個類方法。
__getitem__
方法的是Dataset的核心,作用是接收一個索引, 返回一個樣本, 看上面的函式,引數裡面接收index,然後我們需要編寫究竟如何根據這個索引去讀取我們的資料部分。
2.1 ImageFolder
torchvision已經預先實現了常用的Dataset, 其他預先實現的有: torchvision.datasets.CIFAR10
, 可以讀取CIFAR-10,以及ImageNet、COCO、MNIST、LSUN等資料集。
ImageFolder假設所有的檔案按資料夾儲存,每個資料夾下儲存同一個類別的圖片,資料夾名為類名,其建構函式如下:
ImageFolder(root, transform=None, target_transform=None, loader=default_loader)
引數:
- root: 圖片路徑
- transform: 對PIL Image進行的轉換操作,transform的輸入是使用loader讀取圖片的返回物件
- target_transform:對label的轉換
- loader:給定路徑後如何讀取圖片,預設讀取為RGB格式的PIL Image物件
示例:
資料夾格式:
train_path = r'datasets/myDataSet/train'
預處理格式:
train_transform = transforms.Compose([
transforms.Resize((40,40)),
transforms.RandomCrop(40,padding=4),
transforms.ToTensor(),
transforms.Normalize([0.485,0.456,0.406],
[0.229,0.224,0.225],)
])
dataset:
trainset = ImageFolder(train_path,transform = train_transform)
# print(trainset[30]) # 元組型別,第30號圖片的(畫素資訊,label)
Data.DataLoader:
train_loader = Data.DataLoader(dataset=trainset, batch_size=4,shuffle=False)
for i,(img, target) in enumerate(train_loader):
print(i)
print(img.shape) # (batchsize, channel, H, W)
print(target.shape) # (batch)
print(target) # 一個batch圖片對應的label
2.2
class myDataset(Data.Dataset):
def __init__(self, path, transform):
self.path = path
self.transform = transform
self.data_info = self.get_img_info(path)
self.label = []
for i in range(len(self.data_info)):
self.label.append(list(self.data_info[i])[1])
def __getitem__(self, idx):
path_img = self.data_info[idx][0]
label = self.label[idx]
img = Image.open(path_img).convert('RGB') # 0~255
if self.transform is not None:
img = self.transform(img) # 在這裡做transform,轉為tensor等等
return img, label, idx
def __len__(self):
return len(self.data_info)
@staticmethod
def get_img_info(data_dir):
data_info = list()
for root, dirs, _ in os.walk(data_dir):
# 遍歷類別
for sub_dir in dirs:
img_names = os.listdir(os.path.join(root, sub_dir))
img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))
# 遍歷圖片
for i in range(len(img_names)):
img_name = img_names[i]
path_img = os.path.join(root, sub_dir, img_name)
label = int(sub_dir)
data_info.append((path_img, int(label)))
return data_info
trainset = myDataset(train_path, train_transform)
train_loader = Data.DataLoader(dataset=trainset, batch_size=4,shuffle=True)
for i,(img, target, index) in enumerate(train_loader):
print(i)
print(img.shape) # (batchsize, channel, H, W)
print(target.shape) # (batch)
print(target) # 一個batch的圖片對應的label
print(index) # 一個batch的圖片在資料集中對應的index
s