pytorch資料集和資料處理部分dataset自定義、繼承
https://blog.csdn.net/zhenaoxi1077/article/details/80953227
一、資料載入
在Pytorch 中,資料載入可以通過自己定義的資料集物件來實現。資料集物件被抽象為Dataset類,實現自己定義的資料集需要繼承Dataset,並實現兩個Python魔法方法。
__getitem__: 返回一條資料或一個樣本。 obj[index]等價於obj.__getitem__(index). __len__: 返回樣本的數量。len(obj)等價於obj.__len__().
import torch as t from torch.utils importdata import os from PIL import Image import numpy as np class DogCat(data.Dataset): def __init__(self,root): imgs=os.listdir(root) #所有圖片的絕對路徑 #這裡不實際載入圖片,只是指定路徑,當呼叫__getitem__時才會真正讀圖片 self.imgs=[os.path.join(root, img) for img in imgs] def __getitem__(self, index): img_path=self.imgs[index] #dog->1, cat->0 label=1 if 'dog' in img_path.split("/")[-1] else 0 pil_img=Image.open(img_path) array=np.asarray(pil_img) data=t.from_numpy(array) return data,label def __len__(self): return len(self.image) dataset=DogCat('N:/百度網盤/kaggle/DogCat') img,label=dataset[0]#相當於呼叫dataset.__getitem__(0) for img,label in dataset: print(img.size(),img.float().mean(),label)
二、資料處理transforms
ytorch提供了torchvision。它是一個視覺工具包,提供了很多視覺影象處理的工具。
其中transforms模組提供了對PIL Image物件和Tensor物件的常用操作。
對PIL Image的常見操作如下:
(1)Scale/Resize: 調整尺寸,長寬比保持不變; #Resize
(2)CenterCrop、RandomCrop、RandomSizedCrop:裁剪圖片;
(3)Pad: 填充;
(4)ToTensor: 將PIL Image物件轉換成Tensor,會自動將【0,255】歸一化至【0,1】。
(5)對Tensor的常見操作如下:Normalize: 標準化,即減均值,除以標準差;ToPILImage:將Tensor轉為PIL Image.
如果要對圖片進行多個操作,可通過Compose將這些操作拼接起來,類似於nn.Sequential.這些操作定義之後是以物件的形式存在,真正使用時需要呼叫它的__call__方法,類似於nn.Mudule.
例如:要將圖片調整為224*224,首先應構建操作trans=Scale((224,224)),然後呼叫trans(img).
import os from PIL import Image import numpy as np from torchvision import transforms as T transforms=T.Compose([ T.Resize(224), #縮放圖片(Image),保持長寬比不變,最短邊為224畫素 T.CenterCrop(224), #從圖片中間裁剪出224*224的圖片 T.ToTensor(), #將圖片Image轉換成Tensor,歸一化至【0,1】 T.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5]) #標準化至【-1,1】,規定均值和方差 ]) class DogCat(data.Dataset): def __init__(self,root, transforms=None): imgs=os.listdir(root) self.imgs=[os.path.join(root, img) for img in imgs] self.transforms=transforms def __getitem__(self, index): img_path=self.imgs[index] #dog->1, cat->0 label=1 if 'dog' in img_path.split("/")[-1] else 0 data=Image.open(img_path) if self.transforms: data=self.transforms(data) return data,label def __len__(self): return len(self.imgs) dataset=DogCat('N:/百度網盤/kaggle/DogCat/', transforms=transforms) img,label=dataset[0]#相當於呼叫dataset.__getitem__(0) for img,label in dataset: print(img.size(),label)
三、ImageFolder
下面介紹一個會經常使用到的Dataset——ImageFolder,它的實現和上述DogCat很相似。
四、DataLoader
DataLoader載入資料
Dateset只負責資料的抽象,一次呼叫__getitem__
只返回一個樣本。
在訓練神經網路時,是對一個batch的資料進行操作,同時還要進行shuffle和並行加速等。
對此,pytorch
提供了DataLoader
幫助我們實現這些功能。
dataloader=DataLoader(dataset, batch_size=3, shuffle=True, num_workers=0, drop_last=False)
dataloader是一個可以迭代的物件
五、sampler取樣模組