1. 程式人生 > 其它 >Pytorch Dataset和Dataloader 學習筆記(二)

Pytorch Dataset和Dataloader 學習筆記(二)

Pytorch Dataset & Dataloader

Pytorch框架下的工具包中,提供了資料處理的兩個重要介面,Dataset 和 Dataloader,能夠方便的使用和按批裝載自己的資料集。

  1. 資料的預處理,載入資料並轉化為tensor格式

  2. 使用Dataset構建自己的資料

  3. 使用Dataloader裝載資料

【資料】連結:https://pan.baidu.com/s/1gdWFuUakuslj-EKyfyQYLA
提取碼:10d4
複製這段內容後開啟百度網盤手機App,操作更方便哦

資料的預處理與載入

import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset

## 1. 資料的處理,載入轉化為tensor
x_data = 'X.csv'
y_data = 'y.csv'
x = np.loadtxt(x_data, delimiter=' ', dtype=np.float32)
y = np.loadtxt(y_data, delimiter=' ', dtype=np.float32).reshape(-1, 1)
x = torch.from_numpy(x[:, :])
y = torch.from_numpy(y[:, :])

torch.utils.data.Dataset

Dataset抽象類,用於包裝構建自己的資料集,該類包括三個基本的方法:

  • __init__ 進行資料的讀取操作
  • __getitem__ 資料集需支援索引訪問
  • __len__ 返回資料集的長度
## 2. 構建自己的資料集
class Mydataset(Dataset):
    def __init__(self, train_data, label_data):
        self.train = train_data
        self.label = label_data
        self.len = len(train_data)

    def __getitem__(self, item):
        return self.train[item], self.label[item]

    def __len__(self):
        return self.len

dataset = Mydataset(x, y)
samples = dataset.__len__()
print("總樣本數:",samples)

torch.utils.data.Dataloader

Dataloader抽象類,構建可迭代的資料集裝載器,從Dataset例項物件中按batch_size裝載資料以送入訓練。包含以下幾個引數:

  • batch_size 批大小
  • shuffle 裝載的batch是否亂序
  • drop_last 不足batch大小的最後部分是否捨去
  • num_workers 是否多程序讀取資料
## 3. 建立資料集裝載器
train_loader = DataLoader(dataset=dataset,
                          batch_size=64,
                          shuffle=True,
                          drop_last=True,
                          num_workers=4)

測試

if __name__ == "__main__":
    iteration = 0
    for train_data, train_label in train_loader:
        print("x: ", train_data, "\ny: ", train_label)
        iteration += 1
    ### 這裡dataloader中drop_last為True,所以迭代次數應為 samples/batch_size = 6
    print("每個epoch迭代次數:",iteration)

完整程式碼

import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset

## 1. 資料的處理,載入轉化為tensor
x_data = 'X.csv'
y_data = 'y.csv'
x = np.loadtxt(x_data, delimiter=' ', dtype=np.float32)
y = np.loadtxt(y_data, delimiter=' ', dtype=np.float32).reshape(-1, 1)
x = torch.from_numpy(x[:, :])
y = torch.from_numpy(y[:, :])

## 2. 構建自己的資料集
class Mydataset(Dataset):
    def __init__(self, train_data, label_data):
        self.train = train_data
        self.label = label_data
        self.len = len(train_data)

    def __getitem__(self, item):
        return self.train[item], self.label[item]

    def __len__(self):
        return self.len

dataset = Mydataset(x, y)

## 3. 建立資料集裝載器
train_loader = DataLoader(dataset=dataset,
                          batch_size=64,
                          shuffle=True,
                          drop_last=True,
                          num_workers=4)

if __name__ == "__main__":
    iteration = 0
    samples = dataset.__len__()
    print("總樣本數:", samples)
    for train_data, train_label in train_loader:
        print("x: ", train_data, "\ny: ", train_label)
        iteration += 1
    ### 這裡dataloader中drop_last為True,所以迭代次數應為 samples/batch_size = 6
    print("每個epoch迭代次數:",iteration)