torchvision.utils.make_grid() 拼接圖片：MNIST 資料集（新手）
阿新 • 發佈：2021-02-06
參考:部落格園
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torchvision
import matplotlib.pyplot as plt


def image_show(images):
    """Print the shape of, then display, a channels-first image tensor.

    Args:
        images: Tensor laid out as (C, H, W), e.g. the grid returned by
            ``torchvision.utils.make_grid``. It is detached and copied to
            the CPU first, so CUDA tensors and tensors requiring grad work.
    """
    # Tensor.numpy() refuses CUDA tensors and tensors that require grad.
    images = images.detach().cpu().numpy()
    # matplotlib expects channels-last (H, W, C); torch tensors are (C, H, W).
    images = images.transpose((1, 2, 0))
    print(images.shape)
    if images.shape[-1] == 1:
        # plt.imshow rejects an (H, W, 1) array; drop the channel axis and
        # render the single channel in grayscale instead.
        plt.imshow(images[..., 0], cmap='gray')
    else:
        plt.imshow(images)
    plt.show()


def main(root='./datasets', batch_size=32, download=False):
    """Load the first batch of the MNIST test split, tile it, and display it.

    Args:
        root: Directory holding (or receiving) the MNIST files.
        batch_size: Number of digits drawn into the grid.
        download: Fetch MNIST into ``root`` when it is not already there.
            With the default ``False``, torchvision raises if the files are
            missing.
    """
    # train=False selects the MNIST *test* split, hence the test_* names.
    test_dataset = datasets.MNIST(root=root, train=False, download=download,
                                  transform=transforms.ToTensor())
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # The demo only needs one batch, so pull it straight from the iterator.
    inputs, targets = next(iter(test_loader))
    print(inputs.shape)
    print(targets.shape)

    # make_grid tiles the whole batch into one (C, H, W) image; a
    # single-channel batch is replicated to three channels.
    images = torchvision.utils.make_grid(inputs)
    print(f'images.shape:{images.shape}')
    image_show(images)


if __name__ == '__main__':
    main()