PyTorch學習筆記(2)-使用自定義txt檔案讀取資料
阿新 • • 發佈:2018-12-16
import os

import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import datasets, models, transforms


class dataLoader(Dataset):
    """Dataset backed by a plain-text list of image paths and integer labels.

    Every non-blank line of the list file is split on whitespace: the first
    field is an image path (joined onto ``path``) and the second field is
    parsed as an integer label. Any further fields are ignored.
    """

    def __init__(self, path, listName, dataset='', data_transforms=None, loader=None):
        """Read the list file and remember how images are loaded/transformed.

        Args:
            path: Directory that the image paths in the list file are joined to.
            listName: Path of the text file listing ``<image path> <label>``.
            dataset: Key used to pick a pipeline out of ``data_transforms``.
            data_transforms: Mapping from dataset key to a callable applied to
                each loaded image, or ``None`` to return images untransformed.
            loader: Callable taking an image path and returning the image;
                falls back to ``default_loader`` when not supplied.

        Raises:
            OSError: If the list file cannot be opened.
            IndexError: If a non-blank line has no label field.
            ValueError: If a label field is not an integer.
        """
        self.path = path
        self.listName = listName
        self.images = []
        self.labels = []
        # Read the list once and close it deterministically; the original
        # opened the file twice and never closed either handle.
        with open(self.listName) as listFile:
            for line in listFile:
                fields = line.split()
                if not fields:
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue
                self.images.append(os.path.join(self.path, fields[0]))
                self.labels.append(int(fields[1]))
        self.data_transforms = data_transforms
        self.dataset = dataset
        self.loader = loader if loader else self.default_loader

    def default_loader(self, imageName):
        """Open ``imageName`` with PIL and return it converted to RGB.

        On failure a message is printed and ``None`` is returned, so one
        unreadable file is reported instead of raising from the loader.
        """
        try:
            # The context manager releases the underlying file handle;
            # convert() returns an independent, fully loaded image.
            with Image.open(imageName) as image:
                return image.convert('RGB')
        except Exception as err:
            # The original printed the undefined name `path` here, which
            # raised NameError instead of reporting the offending file.
            print("Cannot read image", imageName, err)
            return None

    def __len__(self):
        """Return the number of entries read from the list file."""
        return len(self.images)

    def __getitem__(self, item):
        """Return ``(image, label)`` for index ``item``.

        The image is passed through ``data_transforms[dataset]`` when
        transforms were supplied. If that lookup or the transform itself
        fails, a message is printed and the untransformed image is returned.
        """
        imageName = self.images[item]
        label = self.labels[item]
        image = self.loader(imageName)
        if self.data_transforms is not None:
            try:
                image = self.data_transforms[self.dataset](image)
            except Exception as err:
                # Includes the error so a wrong `dataset` key (KeyError) is
                # distinguishable from a genuine transform failure.
                print("Cannot transform image", imageName, err)
        return image, label


class dataAugmentation():
    """Container for the torchvision transform pipelines, keyed by dataset name."""

    def __init__(self):
        """Build the "trainImages" and "testImages" pipelines."""
        # NOTE(review): the mean/std values look like the usual ImageNet
        # channel statistics -- confirm they match the model being trained.
        self.data_transforms = {
            # Training: random crop + random horizontal flip for augmentation.
            "trainImages": transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]),
            # Testing: deterministic resize + centre crop.
            "testImages": transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]),
        }


if __name__ == "__main__":
    # NOTE(review): the original snippet used dataPath, dataList and dataset
    # without defining them. The values below are placeholders -- adapt them
    # to the actual data layout before running.
    dataPath = "./data"  # directory the listed image paths are relative to
    dataList = "./data/trainImages.txt"  # one "<image path> <label>" per line
    dataset = "trainImages"  # must be a key of dataAugmentation().data_transforms

    augmentation = dataAugmentation()
    data = dataLoader(
        dataPath,
        dataList,
        dataset=dataset,
        data_transforms=augmentation.data_transforms,
    )
    dataloaders = torch.utils.data.DataLoader(
        data,
        batch_size=100,
        shuffle=False,
        num_workers=8,  # number of worker subprocesses used for data loading
    )
    # From here on, iterate over `dataloaders` to fetch images and labels
    # one batch at a time.