1. 程式人生 > 其它 >Pytorch資料讀取機制(DataLoader)與影象預處理模組(transforms)

Pytorch資料讀取機制(DataLoader)與影象預處理模組(transforms)

Pytorch資料讀取機制(DataLoader)與影象預處理模組(transforms)

1.DataLoader

torch.utils.data.DataLoader():構建可迭代的資料裝載器, 訓練的時候,每一個for迴圈,每一次iteration,就是從DataLoader中獲取一個batch_size大小的資料的。

Dataloader()引數:

  • dataset: Dataset類,決定資料從哪讀取(資料路徑)以及如何讀取(做哪些預處理)
  • batchsize: 批大小
  • num_works: 是否採用多程序讀取機制
  • shuffle: 每一個epoch是否亂序
  • drop_last: 當樣本數不能被batchsize整除時,是否捨棄最後一批資料。

2. Dataset

torch.utils.data.Dataset():Dataset抽象類, 所有自定義的Dataset都需要繼承它,並且必須複寫__getitem__()這個類方法。

__getitem__方法的是Dataset的核心,作用是接收一個索引, 返回一個樣本, 看上面的函式,引數裡面接收index,然後我們需要編寫究竟如何根據這個索引去讀取我們的資料部分。

2.1 ImageFolder

torchvision已經預先實現了常用的Dataset, 其他預先實現的有: torchvision.datasets.CIFAR10, 可以讀取CIFAR-10,以及ImageNet、COCO、MNIST、LSUN等資料集。

ImageFolder假設所有的檔案按資料夾儲存,每個資料夾下儲存同一個類別的圖片,資料夾名為類名,其建構函式如下:

ImageFolder(root, transform=None, target_transform=None, loader=default_loader)

引數:

  • root: 圖片路徑
  • transform: 對PIL Image進行的轉換操作,transform的輸入是使用loader讀取圖片的返回物件
  • target_transform:對label的轉換
  • loader:給定路徑後如何讀取圖片,預設讀取為RGB格式的PIL Image物件

示例:

資料夾格式:

train_path = r'datasets/myDataSet/train'

預處理格式:

train_transform = transforms.Compose([
    transforms.Resize((40,40)),
    transforms.RandomCrop(40,padding=4),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],
                         [0.229,0.224,0.225],)
])

dataset:

trainset = ImageFolder(train_path,transform = train_transform)
# print(trainset[30]) # 元組型別,第30號圖片的(畫素資訊,label)

Data.DataLoader:

train_loader = Data.DataLoader(dataset=trainset, batch_size=4,shuffle=False)
for i,(img, target) in enumerate(train_loader):
    print(i)
    print(img.shape) # (batchsize, channel, H, W)
    print(target.shape) # (batch)
    print(target) # 一個batch圖片對應的label

2.2

class myDataset(Data.Dataset):
    def __init__(self, path, transform):
        self.path = path
        self.transform = transform
        self.data_info = self.get_img_info(path)
        self.label = []
        for i in range(len(self.data_info)):
            self.label.append(list(self.data_info[i])[1])

    def __getitem__(self, idx):
        path_img = self.data_info[idx][0]
        label = self.label[idx]
        img = Image.open(path_img).convert('RGB')  # 0~255
        if self.transform is not None:
            img = self.transform(img)  # 在這裡做transform,轉為tensor等等
        return img, label, idx

    def __len__(self):
        return len(self.data_info)

    @staticmethod
    def get_img_info(data_dir):
        data_info = list()
        for root, dirs, _ in os.walk(data_dir):
            # 遍歷類別
            for sub_dir in dirs:
                img_names = os.listdir(os.path.join(root, sub_dir))
                img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))
                # 遍歷圖片
                for i in range(len(img_names)):
                    img_name = img_names[i]
                    path_img = os.path.join(root, sub_dir, img_name)
                    label = int(sub_dir)
                    data_info.append((path_img, int(label)))
        return data_info

trainset = myDataset(train_path, train_transform)

train_loader = Data.DataLoader(dataset=trainset, batch_size=4,shuffle=True)
for i,(img, target, index) in enumerate(train_loader):
    print(i)
    print(img.shape) # (batchsize, channel, H, W)
    print(target.shape) # (batch)
    print(target) # 一個batch的圖片對應的label
    print(index) #  一個batch的圖片在資料集中對應的index

s