1. 程式人生 > 其它 >MNIST資料集讀取-datasets.MNIST

MNIST資料集讀取-datasets.MNIST

技術標籤:影象處理深度學習

#%%

import torch
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
import sys
batch_size = 2

transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081, ))
    ])

#train = true是訓練集,false為測試集
MNIST_dataset_train = datasets.MNIST(root='./data/mnist', train=False, download=False, transform=transform)
dataloaders_train = DataLoader(dataset=MNIST_dataset_train, batch_size=batch_size, shuffle=True)

#%% 訓練集資料60000張,每次迴圈datasets,輸出x,y;x為N*1*28*28的影象,y為1*N的label
#測試集資料10000張,10000個標籤
i = 0
for x,y in dataloaders_train:
    #獲取一張圖片,和一個圖片的標籤
    if i==0:
        print('label:',y[0])
        plt.imshow(x[0,0,:,:])
        plt.pause(0.001)
    else:
        sys.exit()
    
    i+=1
        
    
    
label: tensor(3)