PyTorch學習記錄003-Dataset和DataLoader
阿新 • • 發佈:2020-12-09
1.utils.data包括Dataset和DataLoader
torch.utils.data.Dataset為抽象類,表示Dataset的抽象類,所有其他資料集都應該進行子類化,所有子類應該override,__len__和__getitem__,前者提供了資料集的大小,後者支援整數索引,範圍從0到len(self)。 自定義資料集需要繼承這個類,並實現兩個函式,一個是__len__,另一個是__getitem__前者提供資料的大小(size),後者通過給定索引獲取資料和標籤__getitem__一次只能獲取一個數據,所以需要通過torch.utils.data.DataLoader來定義一個新的迭代器,實現batch讀取。 首先定義獲取資料集的類,該類繼承基類Dataset,自定義一個數據集及對應標籤。
class TestDataset(data.Dataset): # 繼承Dataset def __init__(self): # 一些由2維向量表示的資料集 self.Data = np.asarray([[1,2],[3,4],[2,1],[3,4],[4,5]]) # 這些是資料集對應的標籤 self.Label = np.asarray([0,1,0,1,2]) def __getitem__(self, index): # 把numpy轉換為tensor txt = torch.from_numpy(self.Data[index]) label = torch.tensor(self.Label[index]) return txt, label def __len__(self): return len(self.Data)
Test = TestDataset()
print(Test[2]) # 相當於呼叫__getitem__(2)
print(Test.__len__())
輸出:
(tensor([2, 1], dtype=torch.int32), tensor(0, dtype=torch.int32))
5
以上資料以tuple返回,每次只返回一個樣本。實際上,Dateset只負責資料的抽取,呼叫一次__getitem__只返回一個樣本。如果希望批量處理(batch),還要同時進行shuffle和並行加速等操作,可選擇DataLoader。
DataLoader的格式為:
data.DataLoader( dataset, # 載入的資料集 batch_size=1, # 批大小 shuffle=False, # 是否將資料打亂 sampler=None, # 樣本抽樣 batch_sampler=None, num_workers=0, # 使用多程序載入的程序數,0代表不適用多程序 collate_fn=<function *> # 如何將多個樣本資料拼成一個batch pin_memory=False, # 是否將資料儲存在pin memory中,pin memory中的資料轉到GPU會快一些 drop_last=False, # dataset中的資料個數可能不是batch_size的整數倍,drop_last為true會將多出來不足一個batch的資料丟棄 timeout=0, worker_init_fn=None, )
建立一個DataLoader:
Test = TestDataset()
test_loader = data.DataLoader(Test, batch_size = 2,
shuffle = False,
num_workers=2,
drop_last = True)
for i, traindata in enumerate(test_loader):
print('i:{}'.format(i))
Data, Label = traindata
print('data:',Data)
print('Label:', Label)
輸出:
i:0
data: tensor([[1, 2],
[3, 4]], dtype=torch.int32)
Label: tensor([0, 1], dtype=torch.int32)
i:1
data: tensor([[2, 1],
[3, 4]], dtype=torch.int32)
Label: tensor([0, 1], dtype=torch.int32)
從這個結果可以看出,這是批量讀取。我們可以像使用迭代器一樣使用它,比如對它進行迴圈操作。不過由於它不是迭代器,我們可以通過iter命令將其轉換為迭代器。
dataiter = iter(test_loader)
imgs,labels = next(dataiter)
一般用data.Dataset處理同一個目錄下的資料。如果資料在不同目錄下,因為不同的目錄代表不同類別(這種情況比較普遍),使用data.Dataset來處理就很不方便。不過,使用PyTorch另一種視覺化資料處理工具(即torchvision)就非常方便,不但可以自動獲取標籤,還提供很多資料預處理、資料增強等轉換函式。