CIFAR10資料集取一張視覺化儲存
阿新 • • 發佈:2018-12-11
transform = transforms.Compose([ transforms.Resize(96), transforms.ToTensor() # transforms.Normalize((.5, .5, .5), (.5, .5, .5)) ]) test_dataset = torchvision.datasets.CIFAR10(root='./data/', train=False, download=True) test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False) # 得到一個隨機的訓練圖片 dataiter = iter(test_loader) images, labels = dataiter.next() images = torch.squeeze(images) # (1,3,96,96) --> (3,96,96) images = torch.transpose(images, 0, -1) # (3,96,96) --> (96,96,3) img = images.numpy() # 將tensor轉換為numpy img = img_as_ubyte(img) # 這點很重要!!這些數值都不是在0-255,所以要轉換為unit8 cv2.imwrite("./test.jpg", img) # 儲存為test.jpg cv2.imshow(img) # 視覺化 cv2.waitKey(0) print "over."