1. 程式人生 > 其它 >MNIST 資料集、資料載入

MNIST 資料集、資料載入

目錄

MNIST 資料集

機器學習的入門就是MNIST。

MNIST 資料集來自美國國家標準與技術研究所,是NIST(National Institute of Standards and Technology)的縮小版,訓練集(training set)由來自 250 個不同人手寫的數字構成,其中 50% 是高中學生,50% 來自人口普查局(the Census Bureau)的工作人員,測試集(test set)也是同樣比例的手寫數字資料。

獲取MNIST
MNIST 資料集可在http://yann.lecun.com/exdb/mnist/獲取,圖片是以位元組的形式進行儲存,它包含了四個部分:

Training set images: train-images-idx3-ubyte.gz (9.9 MB, 解壓後 47 MB, 包含 60,000 個樣本)
Training set labels: train-labels-idx1-ubyte.gz (29 KB, 解壓後 60 KB, 包含 60,000 個標籤)
Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 解壓後 7.8 MB, 包含 10,000 個樣本)
Test set labels: t10k-labels-idx1-ubyte.gz (5KB, 解壓後 10 KB, 包含 10,000 個標籤)

此資料集中,訓練樣本:共60000個,其中55000個用於訓練,另外5000個用於驗證。測試樣本:共10000個,驗證資料比例相同。

from torchvision.datasets import MNIST
mnist_train = MNIST(root='./MNIST_data', train=True, download=True, transform=transforms.PILToTensor())

資料載入

from torch.utils.data import DataLoader
from torchvision.utils import make_grid
dataloader = DataLoader(dataset=mnist_train, batch_size=2, shuffle=True, num_workers=2)
for (images, labels) in dataloader:
    print(labels)
    image = make_grid(images).permute(1, 2, 0).numpy()
    plt.imshow(image)
    plt.show()
    exit()

其中引數含義:

  1. dataset:提前定義的dataset的例項
  2. batch_size:傳入資料的batch的大小,常用128,256等等
  3. shuffle:bool型別,表示是否在每次獲取資料的時候提前打亂資料
  4. num_workers:載入資料的執行緒數

transforms

由於 DataLoader 這個載入器只能載入 tensors, numpy arrays, numbers, dicts or lists

但是 found <class 'PIL.Image.Image'>,所以就很尷尬,我們需要將圖片轉換一下

transforms 用於圖形變換,在使用時我們還可以使用 transforms.Compose將一系列的transforms操作連結起來。

  • torchvision.transforms.Compose([ ts,ts,ts... ])ts為transforms操作
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

大多數情況下我們不會只transforms 一下,所以可以用如下方案

from torchvision import transforms
transforms.Compose(
    [  #文件  https://pytorch.org/vision/stable/transforms.html
        transforms.ToPILImage(),  # 轉成PIL圖片
        # transforms.Resize(size),  # 縮放
        transforms.ToTensor(),  # 變張量
        transforms.Normalize(mean=(0.1307, ), std=(0.3081, )) ]
)

介紹一個概念:

transforms 處理過後,會把通道移到最前邊。比如 MNIST h*w*c 為:28281

tensor處理完,通道數會提前,並且做了軸交換,變為了 c*h*w 為:12828

至於為什麼要這麼設計?聽傳言是做矩陣加減乘除以及卷積等運算是需要呼叫cuda和cudnn的函式的,而這些介面都設成成 chw 格式了