1. 程式人生 > >pytorch程式碼閱讀筆記

pytorch程式碼閱讀筆記

1、資料載入:

class torch.utils.data.Dataset

  表示Dataset的抽象類。

  所有其他資料集都應該進行子類化。所有子類應該override __len____getitem__,前者提供了資料集的大小,後者支援整數索引,範圍從0到len(self)。

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)

  資料載入器。組合資料集和取樣器,並在資料集上提供單程序或多程序迭代器。

  引數:

    dataset (Dataset) – 載入資料的資料集。

    batch_size (int, optional) – 每個batch載入多少個樣本(預設: 1)。

    shuffle (bool, optional) – 設定為True時會在每個epoch重新打亂資料(預設: False)。

    sampler (Sampler, optional) – 定義從資料集中提取樣本的策略。如果指定,則忽略shuffle引數。

    num_workers (int, optional) – 用多少個子程序載入資料。0表示資料將在主程序中載入(預設: 0)

    collate_fn (callable, optional) –將一組取樣組合稱一個mini-batch。

    pin_memory (bool, optional) –如果True,dataloader會在返回張量前將其複製一份到CUDA pinned memory中。

    drop_last (bool, optional) – 如果資料集大小不能被batch size整除,則設定為True後可刪除最後一個不完整的batch。如果設為False並且資料集的大小不能被batch size整除,則最後一個batch將更小。(預設: False)

待續...