pytroch建立自己的Dataset和Dataloader.
阿新 • • 發佈:2018-11-17
首先是引入需要的模組:
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
然後繼承Dataset類,重寫它的三個方法:
class PointDataSet(Dataset): def __init__(self): clouds, labels=get_train_data() self.x_data=clouds self.y_data=labels self.lenth=clouds.size(0) def __getitem__(self, index): return self.x_data[index],self.y_data[index] def __len__(self): return self.lenth
第一個函式是建構函式,也可以理解為是初始化函式,在這裡一般完成資料載入賦值給self下的變數(也就是例項化之後才能引用的變數,區別於靜態變數)。其中的x_data,y_data,length這些變數名字不是固定的,可以按照喜好命名,但也要符合規範。
第二個函式是根據索引獲取資料的方法,在使用迭代器不斷地獲取變數的時候,就會用到這個方法。其中的return函式後面的返回值,可以自定義返回值的數量,可以寫成 return index, self.x_data[index],self.y_data[index]
第三個函式是獲取資料的長度。
接下來是例項化這個物件,例項化後的Dataset用來作為構造Dataloader的引數:
point_data_set=PointDataSet()
data_loader=DataLoader(dataset=point_data_set,batch_size=32,shuffle=True)