pytorch中的torch.utils.data.Dataset和torch.utils.data.DataLoader
阿新 • • 發佈:2018-12-17
首先看torch.utils.data.Dataset這個抽象類。可以使用這個抽象類來構造pytorch資料集。要注意的是以這個類構造的子類,一定要定義兩個函式一個是__len__,另一個是__getitem__,前者提供資料集size,而後者通過給定索引獲取資料和標籤。__getitem__一次只能獲取一個數據(不知道是不是強制性的),所以通過torch.utils.data.DataLoader來定義一個新的迭代器,實現batch讀取。首先我們來定義一個j簡單的資料集:
from torch.utils.data.dataset import Dataset import numpy as np
class TxtDataset(Dataset):#這是一個Dataset子類 def __init__(self): self.Data=np.asarray([[1,2],[3,4],[2,1],[3,4],[4,5]])#特徵向量集合,特徵是2維表示一段文字 Label=np.asarray([1, 2, 0, 1, 2])#標籤是1維,表示文字類別 def __getitem__(self, index): txt=torch.LongTensor(self.Data[index]) label=torch.LongTensor(self.Label[index]) return txt, label #返回標籤 def __len__(self): return len(self.Data)
我們建立一個TxtDataset物件,並呼叫函式,注意__getitem__的呼叫要通過: 物件[索引]呼叫
Txt=TxtDataset()
print(Txt[1])
print(Txt.__len__())
#輸出:
(array([3, 4]), 2)
5
看到輸出中特徵向量和標籤是以tuple返回的。而此處得到樣本是一個不是批量的所以我們使用了torch.utils.data.DataLoader引數有 資料集物件(Dataset)、batc_size、shuffle(設定為真每個epoch會進行重置資料順序,一般在訓練資料中使用)、num_workers(設定多少個子程序可以使用,設定0表示在主程序中使用)
test_loader = DataLoader(Txt,batch_size=2,shuffle=False,
num_workers=4)
for i,traindata in enumerate(test_loader):
print('i:',i)
Data,Label=traindata
print('data:',Data)
print('Label:',Label)
輸出:
i: 0
data: tensor([[ 1, 2],
[ 3, 4]], dtype=torch.int32)
Label: tensor([ 1, 2], dtype=torch.int32)
i: 1
data: tensor([[ 2, 1],
[ 3, 4]], dtype=torch.int32)
Label: tensor([ 0, 1], dtype=torch.int32)
i: 2
data: tensor([[ 4, 5]], dtype=torch.int32)
Label: tensor([ 2], dtype=torch.int32)
在這個例子中設定批量為2,因此每次去出兩個樣本。除了文字資料可以這樣設定,圖片資料集也是可以的。