1. 程式人生 > 其它 >pytorch資料集和資料處理部分dataset自定義、繼承

pytorch資料集和資料處理部分dataset自定義、繼承

https://blog.csdn.net/zhenaoxi1077/article/details/80953227  

一、資料載入

在Pytorch 中,資料載入可以通過自己定義的資料集物件來實現。資料集物件被抽象為Dataset類,實現自己定義的資料集需要繼承Dataset,並實現兩個Python魔法方法。

__getitem__: 返回一條資料或一個樣本。   obj[index]等價於obj.__getitem__(index).

__len__: 返回樣本的數量。len(obj)等價於obj.__len__().
import torch as t
from torch.utils import
data import os from PIL import Image import numpy as np class DogCat(data.Dataset): def __init__(self,root): imgs=os.listdir(root) #所有圖片的絕對路徑 #這裡不實際載入圖片,只是指定路徑,當呼叫__getitem__時才會真正讀圖片 self.imgs=[os.path.join(root, img) for img in imgs] def __getitem__(self, index): img_path
=self.imgs[index] #dog->1, cat->0 label=1 if 'dog' in img_path.split("/")[-1] else 0 pil_img=Image.open(img_path) array=np.asarray(pil_img) data=t.from_numpy(array) return data,label def __len__(self): return len(self.image) dataset=DogCat('
N:/百度網盤/kaggle/DogCat') img,label=dataset[0]#相當於呼叫dataset.__getitem__(0) for img,label in dataset: print(img.size(),img.float().mean(),label)

二、資料處理transforms

ytorch提供了torchvision。它是一個視覺工具包,提供了很多視覺影象處理的工具。
其中transforms模組提供了對PIL Image物件和Tensor物件的常用操作。

對PIL Image的常見操作如下:

(1)Scale/Resize: 調整尺寸,長寬比保持不變; #Resize
(2)CenterCrop、RandomCrop、RandomSizedCrop:裁剪圖片;
(3)Pad: 填充;
(4)ToTensor: 將PIL Image物件轉換成Tensor,會自動將【0,255】歸一化至【0,1】。
(5)對Tensor的常見操作如下:Normalize: 標準化,即減均值,除以標準差;ToPILImage:將Tensor轉為PIL Image.

如果要對圖片進行多個操作,可通過Compose將這些操作拼接起來,類似於nn.Sequential.這些操作定義之後是以物件的形式存在,真正使用時需要呼叫它的__call__方法,類似於nn.Mudule.

例如:要將圖片調整為224*224,首先應構建操作trans=Scale((224,224)),然後呼叫trans(img).

import os
from PIL import Image
import numpy as np
from torchvision import transforms as T
transforms=T.Compose([
    T.Resize(224),  #縮放圖片(Image),保持長寬比不變,最短邊為224畫素
    T.CenterCrop(224), #從圖片中間裁剪出224*224的圖片
    T.ToTensor(), #將圖片Image轉換成Tensor,歸一化至【0,1】
    T.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5])  #標準化至【-1,1】,規定均值和方差
])

class DogCat(data.Dataset):
    def __init__(self,root, transforms=None):
        imgs=os.listdir(root)
        self.imgs=[os.path.join(root, img) for img in imgs]
        self.transforms=transforms

    def __getitem__(self, index):
        img_path=self.imgs[index]
        #dog->1, cat->0
        label=1 if 'dog' in img_path.split("/")[-1] else 0
        data=Image.open(img_path)
        if self.transforms:
            data=self.transforms(data)
        return data,label

    def __len__(self):
        return len(self.imgs)        

dataset=DogCat('N:/百度網盤/kaggle/DogCat/', transforms=transforms)
img,label=dataset[0]#相當於呼叫dataset.__getitem__(0)
for img,label in dataset:
    print(img.size(),label)

三、ImageFolder

下面介紹一個會經常使用到的Dataset——ImageFolder,它的實現和上述DogCat很相似。

 

四、DataLoader

DataLoader載入資料

Dateset只負責資料的抽象,一次呼叫__getitem__只返回一個樣本。
在訓練神經網路時,是對一個batch的資料進行操作,同時還要進行shuffle和並行加速等。
對此,pytorch提供了DataLoader幫助我們實現這些功能。

dataloader=DataLoader(dataset, batch_size=3, shuffle=True, num_workers=0, drop_last=False)

dataloader是一個可以迭代的物件

 

五、sampler取樣模組