1. 程式人生 > >pytroch建立自己的Dataset和Dataloader.

pytroch建立自己的Dataset和Dataloader.

首先是引入需要的模組:

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

然後繼承Dataset類,重寫它的三個方法:

class PointDataSet(Dataset):
    def __init__(self):
        clouds, labels=get_train_data()
        self.x_data=clouds
        self.y_data=labels
        self.lenth=clouds.size(0)
    def __getitem__(self, index):
        return self.x_data[index],self.y_data[index]
    def __len__(self):
        return self.lenth

第一個函式是建構函式,也可以理解為是初始化函式,在這裡一般完成資料載入賦值給self下的變數(也就是例項化之後才能引用的變數,區別於靜態變數)。其中的x_data,y_data,length這些變數名字不是固定的,可以按照喜好命名,但也要符合規範。

第二個函式是根據索引獲取資料的方法,在使用迭代器不斷地獲取變數的時候,就會用到這個方法。其中的return函式後面的返回值,可以自定義返回值的數量,可以寫成        return index, self.x_data[index],self.y_data[index]

第三個函式是獲取資料的長度。

接下來是例項化這個物件,例項化後的Dataset用來作為構造Dataloader的引數:

point_data_set=PointDataSet()
data_loader=DataLoader(dataset=point_data_set,batch_size=32,shuffle=True)