1. 程式人生 > 其它 >torchvision.datasets.CIFAR10()和torch.utils.data.DataLoader詳細介紹

torchvision.datasets.CIFAR10()和torch.utils.data.DataLoader詳細介紹

技術標籤:python資料分析人工智慧python深度學習

import torchvision  

trainset = torchvision.datasets.CIFAR10(root = data_path,
                                        train = True,
                                        download = True,
                                        transform=transform)
'''                                        
root :
cifar-10 的根目錄,data_path(路徑) train : True = 訓練集, False = 測試集 download : True = 從互聯上下載資料,並將其放在root目錄下,如果資料集已經下載,什麼都不幹。 transform:(可呼叫,可選)–接收PIL影象並返回轉換版本的函式/轉換。 '''
# 訓練資料集的載入器,自動將資料分割成batch,順序隨機打亂
batch_size = 32
trainloader= torch.utils.data.DataLoader(dataset=trainset ,
                                           batch_size=
batch_size, drop_last = True , shuffle=True, num_workers=4) ''' 1、dataset:這個就是PyTorch已有的資料讀取介面(比如torchvision.datasets.ImageFolder)或者自定義的資料介面的輸出,該輸出要麼是torch.utils.data.
Dataset類的物件,要麼是繼承自 torch.utils.data.Dataset類的自定義類的物件。 2、batch_size:每個batch載入多少個樣本(預設: 1),根據具體情況設定即可。 3、shuffle:設定為True時會在每個epoch重新打亂資料(預設: False),一般在訓練資料中會採用。 4、collate_fn (callable, optional):是用來處理不同情況下的輸入dataset的封裝,一般採用預設即可,除非你自定義的資料讀取輸出非常少見。合併樣本列表以形成一個 mini-batch. # callable可呼叫物件 5、batch_sampler:從註釋可以看出,其和batch_size、shuffle等引數是互斥的,一般採用預設。 6、sampler:從程式碼可以看出,其和shuffle是互斥的,一般預設即可。 7、num_workers:從註釋可以看出這個引數必須大於等於00的話表示資料匯入在主程序中進行,其他大於0的數表示通過多個程序來匯入資料,可以加快資料匯入速度。 8、pin_memory,註釋寫得很清楚了: pin_memory (bool, optional): 如果為真,資料載入器將在返回張量之前將其複製到CUDA固定記憶體中,然後再返回它們. 9、timeout(numeric, optional):是用來設定資料讀取的超時時間的,但超過這個時間還沒讀取到資料的話就會報錯.如果為正值,則為從工作人員收集批次的超時值.應始終是非負的,(預設:0. 10、drop_last (bool, optional): 設定為 True 如果資料集大小不能被批量大小整除的時候, 將丟掉最後一個不完整的batch,(預設:False). '''