pytorch程式碼閱讀筆記
阿新 • • 發佈:2019-01-05
1、資料載入:
class torch.utils.data.Dataset
表示Dataset的抽象類。
所有其他資料集都應該進行子類化。所有子類應該override __len__和__getitem__,前者提供了資料集的大小,後者支援整數索引,範圍從0到len(self)。
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
資料載入器。組合資料集和取樣器,並在資料集上提供單程序或多程序迭代器。
引數:
dataset (Dataset) – 載入資料的資料集。
batch_size (int, optional) – 每個batch載入多少個樣本(預設: 1)。
shuffle (bool, optional) – 設定為True時會在每個epoch重新打亂資料(預設: False)。
sampler (Sampler, optional) – 定義從資料集中提取樣本的策略。如果指定,則忽略shuffle引數。
num_workers (int, optional) – 用多少個子程序載入資料。0表示資料將在主程序中載入(預設: 0)
collate_fn (callable, optional) –將一組取樣組合稱一個mini-batch。
pin_memory (bool, optional) –如果True,dataloader會在返回張量前將其複製一份到CUDA pinned memory中。
drop_last (bool, optional) – 如果資料集大小不能被batch size整除,則設定為True後可刪除最後一個不完整的batch。如果設為False並且資料集的大小不能被batch size整除,則最後一個batch將更小。(預設: False)
待續...