1. 程式人生 > 實用技巧 >PyTorch學習記錄003-Dataset和DataLoader

PyTorch學習記錄003-Dataset和DataLoader

1.utils.data包括Dataset和DataLoader

  torch.utils.data.Dataset為抽象類,表示Dataset的抽象類,所有其他資料集都應該進行子類化,所有子類應該override,__len__和__getitem__,前者提供了資料集的大小,後者支援整數索引,範圍從0到len(self)。
  自定義資料集需要繼承這個類,並實現兩個函式,一個是__len__,另一個是__getitem__前者提供資料的大小(size),後者通過給定索引獲取資料和標籤__getitem__一次只能獲取一個數據,所以需要通過torch.utils.data.DataLoader來定義一個新的迭代器,實現batch讀取。
  首先定義獲取資料集的類,該類繼承基類Dataset,自定義一個數據集及對應標籤。
class TestDataset(data.Dataset): # 繼承Dataset
    def __init__(self):
        # 一些由2維向量表示的資料集
        self.Data = np.asarray([[1,2],[3,4],[2,1],[3,4],[4,5]]) 
        # 這些是資料集對應的標籤
        self.Label = np.asarray([0,1,0,1,2])
        
    def __getitem__(self, index):
        # 把numpy轉換為tensor
        txt = torch.from_numpy(self.Data[index])
        label = torch.tensor(self.Label[index])
        return txt, label
    
    def __len__(self):
        return len(self.Data)

Test = TestDataset()
print(Test[2]) # 相當於呼叫__getitem__(2)
print(Test.__len__())

輸出:

(tensor([2, 1], dtype=torch.int32), tensor(0, dtype=torch.int32))
5
  以上資料以tuple返回,每次只返回一個樣本。實際上,Dateset只負責資料的抽取,呼叫一次__getitem__只返回一個樣本。如果希望批量處理(batch),還要同時進行shuffle和並行加速等操作,可選擇DataLoader。

DataLoader的格式為:

data.DataLoader(
	dataset,                # 載入的資料集
	batch_size=1,			# 批大小
	shuffle=False,  		# 是否將資料打亂
	sampler=None,			# 樣本抽樣
	batch_sampler=None,
	num_workers=0,			# 使用多程序載入的程序數,0代表不適用多程序
	collate_fn=<function *>	# 如何將多個樣本資料拼成一個batch
	pin_memory=False,		# 是否將資料儲存在pin memory中,pin memory中的資料轉到GPU會快一些
	drop_last=False,		# dataset中的資料個數可能不是batch_size的整數倍,drop_last為true會將多出來不足一個batch的資料丟棄
	timeout=0,
	worker_init_fn=None,
)

建立一個DataLoader:

Test = TestDataset()
test_loader = data.DataLoader(Test, batch_size = 2, 
				    	shuffle = False, 
				    	num_workers=2, 
				    	drop_last = True)
for i, traindata in enumerate(test_loader):
    print('i:{}'.format(i))
    Data, Label = traindata
    print('data:',Data)
    print('Label:', Label)

輸出:

i:0
data: tensor([[1, 2],
        [3, 4]], dtype=torch.int32)
Label: tensor([0, 1], dtype=torch.int32)
i:1
data: tensor([[2, 1],
        [3, 4]], dtype=torch.int32)
Label: tensor([0, 1], dtype=torch.int32)
  從這個結果可以看出,這是批量讀取。我們可以像使用迭代器一樣使用它,比如對它進行迴圈操作。不過由於它不是迭代器,我們可以通過iter命令將其轉換為迭代器。
dataiter = iter(test_loader)
imgs,labels = next(dataiter)
  一般用data.Dataset處理同一個目錄下的資料。如果資料在不同目錄下,因為不同的目錄代表不同類別(這種情況比較普遍),使用data.Dataset來處理就很不方便。不過,使用PyTorch另一種視覺化資料處理工具(即torchvision)就非常方便,不但可以自動獲取標籤,還提供很多資料預處理、資料增強等轉換函式。