1. 程式人生 > >pytorch 加載數據集

pytorch 加載數據集

類型 變量 必須 信息 info true path lis item

pytorch初學者,想加載自己的數據,了解了一下數據類型、維度等信息,方便以後加載其他數據。

1 torchvision.transforms實現數據預處理

transforms.Totensor()操作必須要有,將數據轉為張量格式。

2 torch.utils.data.Dataset實現數據讀取

要使用自己的數據集,需要構建Dataset子類,定義子類為MyDataset,在MyDataset的init函數中定義path_dict變量,來獲取不同類型的數據的路徑。

定義子類MyDataset時,必須要重載兩個函數 getitem 和 len,

__getitem__:實現數據集的下標索引,返回對應的數據及標簽;

__len__:返回數據集的大小。

設加載的數據集大小為L;

定義MyDataset實例:my_datasets = MyDataset(data_dir, transform = data_transform) 。

技術分享圖片

my_datasets 由L個tuple組成,len(my_datasets) = L;

每個tuple長度為2:0:tensor 樣本(Channel,Height,Width)

1:int 標簽

技術分享圖片

技術分享圖片 技術分享圖片

技術分享圖片

3 torch.utils.data.DataLoader實現數據集加載

torch.utils.data.DataLoader()合成數據並提供叠代訪問,由兩部分組成:

—dataset(Dataset):輸入要加載的數據,就是上面的my_datasets;

—batch_size,shuffle,sampler,batch_sampler,num_workers,collate_fn, drop_last,timeout,worker_init_fn等參數。

其中:batch_size:批尺寸,默認為1;

   shuffle:是否在每個epoch開始隨機打亂數據,默認為False;

設data_loader長度為 l ;

加載數據:data_loader = DataLoader(my_datasets, batch_size = BATCH_SIZE, shuffle = True)

data_loader 由 l 個 tuple組成,l = len(data_loader) = len(my_datasets) / batch_size;

叠代訪問:

技術分享圖片

技術分享圖片

e 長度為2:0:int step 表示第幾個batch

1:list(長度為2)表示一個batch包含的所有樣本和標簽

0:tensor 樣本(Batch_size,Channel,Height,Width)

1:tensor 標簽 Batch_size

技術分享圖片

技術分享圖片

pytorch 加載數據集