1. 程式人生 > 其它 >PyTorch筆記之Dataset 和 Dataloader

PyTorch筆記之Dataset 和 Dataloader

技術標籤:pytorch深度學習

PyTorch筆記之Dataset 和 Dataloader

PyTorch筆記之 Dataset 和 Dataloader
        </h1>
        <div class="clear"></div>
        <div class="postBody">

簡介#

在 PyTorch 中,我們的資料集往往會用一個類去表示,在訓練時用 Dataloader 產生一個 batch 的資料

https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py

比如官方例子中對 CIFAR10 影象資料集進行分類,就有用到這樣的操作,具體程式碼如下所示

複製程式碼
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root=’./data’, train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=2)

複製程式碼

簡單說,用 一個類 抽象地表示資料集,而 Dataloader 作為迭代器,每次產生一個 batch 大小的資料,節省記憶體

Dataset#

Dataset 是 PyTorch 中用來表示資料集的一個抽象類,我們的資料集可以用這個類來表示,至少覆寫下面兩個方法即可

這返回資料前可以進行適當的資料處理,比如將原文用一串數字序列表示

  • __len__:資料集大小
  • __getitem__:實現這個方法後,可以通過下標的方式( dataset[i] )的來取得第 ii 個數據

下面我們來為編寫一個類表示一個情感二分類資料集,繼續用蘇神整理的資料集

https://github.com/bojone/bert4keras/tree/master/examples/datasets

資料集沒有表頭,只有2列,一列是評論(文字),另一列是標籤,以製表符進行分隔

複製程式碼
from torch.utils.data import Dataset, DataLoader
import pandas as pd

class SentimentDataset(Dataset):
def init(self, path_to_file):
self.dataset = pd.read_csv(path_to_file, sep="\t", names=[“text”, “label”])
def len(self):
return len(self.dataset)
def getitem(self, idx):
text = self.dataset.loc[idx, “text”]
label = self.dataset.loc[idx, “label”]
sample = {“text”: text, “label”: label}
return sample

複製程式碼

Dataloader#

基本使用#

Dataloader 就是一個迭代器,最基本的使用就是傳入一個 Dataset 物件,它就會根據引數 batch_size 的值生成一個 batch 的資料

複製程式碼
if __name__ == "__main__":
    sentiment_dataset = SentimentDataset("sentiment.test.data")
    sentiment_dataloader = DataLoader(sentiment_dataset, batch_size=4, shuffle=True, num_workers=2)
    for idx, batch_samples in enumerate(sentiment_dataloader):
        text_batchs, text_labels = batch_samples["text"], batch_samples["label"]
        print(text_batchs)
複製程式碼

Sampler#

PyTorch 提供了 Sampler 模組,用來對資料進行取樣,可以在 DataLoader 的通過 sampler 引數呼叫

一般我們的載入訓練集的 dataloader ,shuffle引數都會設定為True ,這時候使用了一個預設的取樣器——RandomSampler

當 shuffle 設定為 False 時,預設使用的是 SequencetialSampler,其實就是按順序取出資料集中的元素

在 PyTorch 中預設實現了以下 Sampler,如果我們要使用別的 Sampler, shuffle 要設定為 False

  • SequentialSampler
  • RandomSampler
  • WeightedSampler
  • SubsetRandomSampler

SubsetRandomSampler 常用來將資料集劃分為訓練集和測試集,比如這裡就訓練集和測試集按7:3 進行分割

複製程式碼
n_train = len(sentiment_train_set)
split = n_train // 3

indices = list(range(n_train)) train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split:]) valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split])
train_loader = DataLoader(sentiment_train_set, sampler=train_sampler, shuffle=False) valid_loader = DataLoader(sentiment_train_set, sampler=valid_sampler, shuffle=False)
複製程式碼

具體推薦下面的博文,講得挺詳細的

一文弄懂Pytorch的DataLoader, DataSet, Sampler之間的關係

https://www.cnblogs.com/marsggbo/p/11541054.html

Pytorch Sampler詳解

https://www.cnblogs.com/marsggbo/p/11541054.html

collate_fn#

可以用來進行一些資料處理,比如在文字任務中,一般由於文字長度不一致,我們需要進行截斷或者填充。對於圖片,我們則希望它們有同樣的尺寸

我麼可以編寫一個函式,然後用這個引數呼叫它,下面是一個簡單的例子,我們把文字截斷成只有10個字元

複製程式碼
def truncate(data_list):
  """傳進一個batch_size大小的資料"""
  for data in data_list:
    text = data["text"]
    data["text"]=text[:10]
  return data_list

test_loader = DataLoader(sentiment_train_set, batch_size=batch_size, shuffle=True, num_workers=2, collate_fn=truncate)

複製程式碼

我們可以看看返回的內容是否已經經過截斷了

for i in test_loader:
  print(i)
  break

這時候返回的是一個列表而不是字典了,其中一個 batch 的返回結果如下,我們可以看到這裡一個樣本放在了一個字典中

[{'text': '看了一個通宵,實在是', 'label': 1}, 。。。, {'text': '看了攜程的其他使用者評', 'label': 0}]

下面是沒有使用 collate_fn 的返回結果,它會將資料和標籤分開,存放在一起,如下所示

{

'text':['3月1號訂的,3月15號還沒到貨 客服每天說下個工作日能到貨已經連續5天了 我無語。想早點兒看這本書的人還是去陶寶或卓越上訂吧,尤其是廣東省的朋友.噹噹送貨太沒保證了.',。。。, '非常純樸的故事,但包含了主人公坎坷的一生,活著就是痛苦,不得不佩服生命的韌性'],

'label': tensor([1, 。。。, 0])

}

作者: 那少年和狗

出處:https://www.cnblogs.com/dogecheng/p/11930535.html

版權:本文采用「CC BY 4.0」知識共享許可協議進行許可。

標籤: PyTorch
<div id="blog_post_info">
好文要頂 關注我 收藏該文 那少年和狗
關注 - 3
粉絲 - 48 +加關注 1 0
<div class="clear"></div>
<div id="post_next_prev">