pytorch (四) 資料載入
阿新 • • 發佈:2020-07-23
自定義載入資料
torch.utils.data.Dataset是一個抽象類,使用者想要載入自定義的資料只需要繼承這個類,並且覆寫其中的兩個方法即可:
- __len__:實現len(dataset)返回整個資料集的大小。
- __getitem__用來獲取一些索引的資料,使dataset[i]返回資料集中第i個樣本。
- 不覆寫這兩個方法會直接返回錯誤。
from torch.utils.data import DataLoader,Dataset class MyData(Dataset): #繼承Dataset def __init__(self, root_dir, transform=None): #初始化圖片路徑,一些變換操作。 self.root_dir = root_dir #檔案目錄 self.transform = transform #變換 self.images = os.listdir(self.root_dir)#目錄裡的所有檔案 def __len__(self):#返回整個資料集的大小 return len(self.images) def __getitem__(self,index):#根據索引index返回dataset[index] image_index = self.images[index]#根據索引index獲取該圖片 img_path = os.path.join(self.root_dir, image_index)#獲取索引為index的圖片的路徑名 img = io.imread(img_path)# 讀取該圖片 label = img_path.split('\\')[-1].split('.')[0]# 根據該圖片的路徑名獲取該圖片的label sample = {'image':img,'label':label}#根據圖片和標籤建立字典 if self.transform: sample = self.transform(sample)#對樣本進行變換 return sample #返回該樣本
之後使用torch.utils.data.DataLoader載入資料
data = MyData('path',transform=None)#初始化類,設定資料集所在路徑以及變換
dataloader = DataLoader(data,batch_size=128,shuffle=True)#使用DataLoader載入資料
載入時不要涉及預處理,把該預處理的都提前做完。比如resize事先處理完,crop,flip和normalize在載入時候處理。