1. 程式人生 > 實用技巧 >pytorch (四) 資料載入

pytorch (四) 資料載入

自定義載入資料

torch.utils.data.Dataset是一個抽象類,使用者想要載入自定義的資料只需要繼承這個類,並且覆寫其中的兩個方法即可:

  1. __len__:實現len(dataset)返回整個資料集的大小。
  2. __getitem__用來獲取一些索引的資料,使dataset[i]返回資料集中第i個樣本。
  3. 不覆寫這兩個方法會直接返回錯誤。
from torch.utils.data import DataLoader,Dataset
class MyData(Dataset): #繼承Dataset
    def __init__(self, root_dir, transform=None): #初始化圖片路徑,一些變換操作。
        self.root_dir = root_dir   #檔案目錄
        self.transform = transform #變換
        self.images = os.listdir(self.root_dir)#目錄裡的所有檔案
    
    def __len__(self):#返回整個資料集的大小
        return len(self.images)
    
    def __getitem__(self,index):#根據索引index返回dataset[index]
        image_index = self.images[index]#根據索引index獲取該圖片
        img_path = os.path.join(self.root_dir, image_index)#獲取索引為index的圖片的路徑名
        img = io.imread(img_path)# 讀取該圖片
        label = img_path.split('\\')[-1].split('.')[0]# 根據該圖片的路徑名獲取該圖片的label
        sample = {'image':img,'label':label}#根據圖片和標籤建立字典
        
        if self.transform:
            sample = self.transform(sample)#對樣本進行變換
        return sample #返回該樣本

之後使用torch.utils.data.DataLoader載入資料

data = MyData('path',transform=None)#初始化類,設定資料集所在路徑以及變換
dataloader = DataLoader(data,batch_size=128,shuffle=True)#使用DataLoader載入資料

載入時不要涉及預處理,把該預處理的都提前做完。比如resize事先處理完,crop,flip和normalize在載入時候處理。