1. 程式人生 > >CIFAR10資料集取一張視覺化儲存

CIFAR10資料集取一張視覺化儲存

transform = transforms.Compose([
        transforms.Resize(96),
        transforms.ToTensor()
        # transforms.Normalize((.5, .5, .5), (.5, .5, .5))
    ])

    test_dataset = torchvision.datasets.CIFAR10(root='./data/', train=False, download=True)

    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=1,
                             shuffle=False)
    # 得到一個隨機的訓練圖片
    dataiter = iter(test_loader)
    images, labels = dataiter.next()


    images = torch.squeeze(images) # (1,3,96,96) --> (3,96,96)
    images = torch.transpose(images, 0, -1) # (3,96,96) --> (96,96,3)
    img = images.numpy() # 將tensor轉換為numpy
    img = img_as_ubyte(img) # 這點很重要!!這些數值都不是在0-255,所以要轉換為unit8
    cv2.imwrite("./test.jpg", img) # 儲存為test.jpg
    cv2.imshow(img) # 視覺化
    cv2.waitKey(0)
    print "over."