1. 程式人生 > 其它 >pytorch讀取資料(Dataset, DataLoader, DataLoaderIter)

pytorch讀取資料(Dataset, DataLoader, DataLoaderIter)

技術標籤:學習總結pytorch

pytorch資料讀取

參考資料:
pytorch資料讀取
pytorch對nlp資料的處理部落格(以短文字匹配為例)
dataloader使用教程部落格
pytorch使用DataLoader對資料集進行批處理簡單示例

Pytorch的資料讀取主要包含三個類:

  1. Dataset
  2. DataLoader
  3. DataLoaderIter

這三者是依次封裝的關係,Dataset被裝進DataLoader,DataLoder被裝進DataLoaderIter。

Dataloader的處理邏輯是先通過Dataset類裡面的__getitem__函式獲取單個的資料,然後組合成batch,再使用collate_fn所指定的函式對這個batch做一些操作,比如padding。

torch.utils.data.Dataset

是一個抽象類,自定義的Dataset需要繼承它並實現兩個成員方法

  1. __getitem__():從資料集中得到一個數據片段(如:資料,標籤)
  2. __len__():返回整個資料集的長度

自定義Dataset基本的框架是:

class CustomDataset(data.Dataset):#需要繼承data.Dataset 
    def __init__(self): 
        # TODO 
        # 1. Initialize file path or list of file names. 
        pass 
        def
__getitem__(self, index): # TODO # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open). # 2. Preprocess the data (e.g. torchvision.Transform). # 3. Return a data pair (e.g. image and label). #這裡需要注意的是,第一步:read one data,是一個data
pass def __len__(self): # You should change 0 to the total size of your dataset. return 0

以短文字匹配為例,判斷兩個文字是否相似,資料格式為(句子1,句子2,是否相似標籤0或1,示例如下圖所示:
在這裡插入圖片描述

對於這樣的資料集,我們構建的Dataset類可以是:

class Dataset(torch.utils.data.Dataset):
    def __init__(self, texta, textb, label):
        self.texta = texta
        self.textb = textb
        self.label = label
    
    def __getitem__(self, item):
        return self.texta[item], self.textb[item], self.label[item]
        
    def __len__(self):
        return len(self.texta)

接下來可以向自己實現的Dataset類中傳值:

train_data = Dataset(train_texta, train_textb, train_label)

torch.utils.data.DataLoader

類的定義為:

class torch.utils.data.DataLoader(
    dataset, 
    batch_size=1, 
    shuffle=False, 
    sampler=None, 
    batch_sampler=None, 
    num_workers=0, 
    collate_fn=<function default_collate>,
    pin_memory=False, 
    drop_last=False
)

主要的引數有:

  1. dataset:即我們自定義的Dataset
  2. collate_fn:這個函式用來打包batch。定義如何把一批dataset的例項轉換為包含mini-batch資料的張量。可以通過自定義collate_fn=myfunction來設計資料打包的方式,通常在myfunction函式中做padding,將同一個batch中不一樣長的句子補全成一樣的長度。
  3. num_workers:設定>=1時,可多執行緒讀資料。設定=0時,單執行緒讀資料
  4. shuffle:代表是否要打亂資料,一般對於訓練集資料都是要打亂的,驗證集可以打亂也可以不打亂,但是測試集資料不打亂。

這個類是DataLoaderIter的一個框架,一共做了兩件事情:

  1. 定義成員變數,到時候賦給DataLoaderIter
  2. 有一個__iter__()函式,把自己裝進DataLoaderIter中
def __iter__(self):
    return DataLoaderIter(self)

仍然以短文字匹配任務為例,上面已經得到了train_data,然後就可以使用DataLoader進行處理,DataLoader每次返回一個batch的資料。

train_iter = DataLoader(dataset=train_data, batch_size=args.batch_size, shuffle=True, collate_fn=myfunction)

這裡的myfunction可以定義為:

def myfunction(batch_data, pad=0):
    texta,textb,label = list(zip(*batch_data))#batch_data的結構是[([texta_1],[textb_1],[label_1]),([texta_2],[textb_2],[label_2]),...],所以需要使用zip函式對它解壓
    max_len_a = max([len(seq_a) for seq_a in texta])
    max_len_b = max([len(seq_b) for seq_b in textb])
    max_len = max(max_len_a,max_len_b) #這裡我使用的是一個batch中text_a或者是text_b的最大長度作為max_len,也可以自定義長度
    texta = [seq+[pad]*(max_len-len(seq)) for seq in texta]
    textb = [seq+[pad]*(max_len-len(seq)) for seq in textb]
    texta = torch.LongTensor(texta)
    textb = torch.LongTensor(textb)
    label = torch.FloatTensor(label)
    return (texta,textb,label)

torch.utils.data.dataloader.DataLoaderIter

這裡其實是迭代呼叫DataLoader的過程。一般載入資料的整個流程為:

class Dataset(Dataset)
    # 自定義的Dataset
    # 返回(data, label)

dataset = Dataset()
dataloader = DataLoader(dataset, ...)
for data in dataloader:
    # training

在for迴圈中,總共有三個操作:

  1. 呼叫DataLoder的__iter__()方法,產生一個DataLoaderIter
  2. 反覆呼叫DataLoaderIter 的__next__()來得到batch, 具體操作就是, 多次呼叫dataset的__getitem__()方法 (如果num_worker>0就多執行緒呼叫), 然後用collate_fn來把它們打包成batch
  3. 當資料讀完後, next()丟擲一個StopIteration異常, for迴圈結束, dataloader 失效

以上的短文字匹配的例子中,for迴圈部分可以寫成:

for batch_data in tqdm(train_iter):
    texta, textb, tag = map(lambda x: x.to(device), batch_data)
    output = model(texta, textb)