MNIST資料集讀取-datasets.MNIST
阿新 • • 發佈:2021-01-28
#%% import torch from torchvision import datasets from torch.utils.data import DataLoader from torchvision import transforms import matplotlib.pyplot as plt import sys batch_size = 2 transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081, )) ]) #train = true是訓練集,false為測試集 MNIST_dataset_train = datasets.MNIST(root='./data/mnist', train=False, download=False, transform=transform) dataloaders_train = DataLoader(dataset=MNIST_dataset_train, batch_size=batch_size, shuffle=True) #%% 訓練集資料60000張,每次迴圈datasets,輸出x,y;x為N*1*28*28的影象,y為1*N的label #測試集資料10000張,10000個標籤 i = 0 for x,y in dataloaders_train: #獲取一張圖片,和一個圖片的標籤 if i==0: print('label:',y[0]) plt.imshow(x[0,0,:,:]) plt.pause(0.001) else: sys.exit() i+=1
label: tensor(3)