pytorch 加載數據集
pytorch初學者,想加載自己的數據,了解了一下數據類型、維度等信息,方便以後加載其他數據。
1 torchvision.transforms實現數據預處理
transforms.Totensor()操作必須要有,將數據轉為張量格式。
2 torch.utils.data.Dataset實現數據讀取
要使用自己的數據集,需要構建Dataset子類,定義子類為MyDataset,在MyDataset的init函數中定義path_dict變量,來獲取不同類型的數據的路徑。
定義子類MyDataset時,必須要重載兩個函數 getitem 和 len,
__getitem__:實現數據集的下標索引,返回對應的數據及標簽;
__len__:返回數據集的大小。
設加載的數據集大小為L;
定義MyDataset實例:my_datasets = MyDataset(data_dir, transform = data_transform) 。
my_datasets 由L個tuple組成,len(my_datasets) = L;
每個tuple長度為2:0:tensor 樣本(Channel,Height,Width)
1:int 標簽
3 torch.utils.data.DataLoader實現數據集加載
torch.utils.data.DataLoader()合成數據並提供叠代訪問,由兩部分組成:
—dataset(Dataset):輸入要加載的數據,就是上面的my_datasets;
—batch_size,shuffle,sampler,batch_sampler,num_workers,collate_fn, drop_last,timeout,worker_init_fn等參數。
其中:batch_size:批尺寸,默認為1;
shuffle:是否在每個epoch開始隨機打亂數據,默認為False;
設data_loader長度為 l ;
加載數據:data_loader = DataLoader(my_datasets, batch_size = BATCH_SIZE, shuffle = True)
data_loader 由 l 個 tuple組成,l = len(data_loader) = len(my_datasets) / batch_size;
叠代訪問:
e 長度為2:0:int step 表示第幾個batch
1:list(長度為2)表示一個batch包含的所有樣本和標簽
0:tensor 樣本(Batch_size,Channel,Height,Width)
1:tensor 標簽 Batch_size
pytorch 加載數據集