1. 程式人生 > >pytorch筆記(二)

pytorch筆記(二)

https://blog.csdn.net/u012609509/article/details/81264687
當載入MNIST、CIFAR10等自帶的資料集時,使用torchvision.datasets.MNIST(....),然後使用torch.utils.data.DataLoader讀取資料
當使用自定義的資料集時,應該先繼承torch.utils.data.Dataset類,並且重寫__len()__和—__getitem()方法__
```
# coding=utf-8
import matplotlib as mpl
mpl.use('tkagg')   # 除錯:agg;  執行: tkagg
import matplotlib.pyplot as plt
 
import os
import pandas as pd
import torch
 
"""
torch.utils.data.Dataset 是一個表示資料集的抽象類.
你自己的資料集一般應該繼承``Dataset``, 並且重寫下面的方法:
    1. __len__ 使用``len(dataset)`` 可以返回資料集的大小
    2. __getitem__ 支援索引, 以便於使用 dataset[i] 可以 獲取第i個樣本(0索引)
"""
from torch.utils.data import Dataset
 
 
"""
torch.utils.data中的DataLoader提供為Dataset類物件提供了:
    1.批量讀取資料
    2.打亂資料順序
    3.使用multiprocessing並行載入資料
    
    DataLoader中的一個引數collate_fn:可以使用它來指定如何精確地讀取一批樣本,
     merges a list of samples to form a mini-batch.
    然而,預設情況下collate_fn在大部分情況下都表現很好


注意:當使用torchvision中自帶的資料集時,download可一直設定為TURE,但root需要設定為壓縮檔案所在的根目錄,不需要具體到訓練集,當壓縮檔案已經存在時,不會從網上再次下載,如果不存在,則會下載到root所表示的根目錄。