記錄對於__getitem__的理解
阿新 • • 發佈:2018-12-13
import os

import numpy as np
import torch as t
from PIL import Image
from torch.utils import data


class DogCat(data.Dataset):
    """Map-style dataset over the image files found directly in one directory.

    Each item is ``(tensor, label)``:
      * ``tensor`` -- the image's pixel array converted with NumPy and wrapped
        by ``torch.from_numpy``; no resizing or normalisation is applied
        (for RGB JPEGs this is typically H x W x 3 uint8 -- verify per data).
      * ``label`` -- 1 if the file name contains ``'dog'``, otherwise 0.

    Example layouts:
        train: data/train/cat.10004.jpg  -> label 0
        test1: data/test1/8973.jpg       -> label 0 (no class in the name)
    """

    def __init__(self, root):
        """Collect the paths of all entries directly under ``root``.

        Args:
            root: directory containing the image files.
        """
        # os.listdir returns entries in arbitrary order; sort so that the
        # index -> file mapping is reproducible across runs and platforms.
        self.imgs = [
            os.path.join(root, name) for name in sorted(os.listdir(root))
        ]

    def __getitem__(self, index):
        """Load image number ``index`` and return ``(tensor, label)``.

        Called automatically for ``dataset[index]``. Because this class
        defines no ``__iter__``, ``for ... in dataset`` also lands here,
        calling it with 0, 1, 2, ... until ``self.imgs[index]`` raises
        IndexError, which ends the loop. Images are therefore read one at
        a time, on demand.

        Raises:
            IndexError: if ``index`` is out of range.
        """
        img_path = self.imgs[index]
        # Use basename rather than split('/'): on Windows os.path.join uses
        # '\\', and split('/') would search the whole path (including root).
        label = 1 if 'dog' in os.path.basename(img_path) else 0
        # Close the file as soon as the pixels are decoded. np.array makes a
        # writable copy that outlives the file; a read-only buffer would make
        # torch.from_numpy warn about non-writable memory.
        with Image.open(img_path) as pil_img:
            array = np.array(pil_img)
        return t.from_numpy(array), label

    def __len__(self):
        """Return the number of entries found under ``root``."""
        return len(self.imgs)


if __name__ == "__main__":
    datasets = DogCat('./data/train')
    # Each iteration calls __getitem__ once, loading a single image.
    for img, label in datasets:
        print(img.size(), img.float().mean(), label)
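對資料集進行索引（例如 datasets[0]）時，會自動呼叫 __getitem__ 方法，每次只讀取一張圖片。由於 DogCat 沒有定義 __iter__，上面的 for img, label in datasets 迴圈同樣會依序以 0、1、2…… 呼叫 __getitem__，直到索引越界拋出 IndexError 為止，因此圖片是在迭代過程中逐張地讀取，而不是一次全部載入記憶體。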