1. 程式人生 > >PyTorch讀取Cifar資料集並顯示圖片

PyTorch讀取Cifar資料集並顯示圖片

首先了解一下需要的幾個類所在的package

這裡寫圖片描述

from torchvision import transforms, datasets as ds
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

#transform = transforms.Compose是把一系列圖片操作組合起來,比如減去畫素均值等。
#DataLoader讀入的資料型別是PIL.Image
#這裡對圖片不做任何處理,僅僅是把PIL.Image轉換為torch.FloatTensor,從而可以被pytorch計算
transform = transforms.Compose( [ transforms.ToTensor() ] )

Step 1,得到torch.utils.data.Dataset例項。

  • torch.utils.data.Dataset是一個抽象類,CIFAR100是它的一個例項化子類
  • train=True,讀取訓練集;train=False,讀取測試集
  • download=False,不下載。如果為True,則先檢查root下有無該資料集,如果沒有就先下載。
train_set = ds.CIFAR100(root='.', train=True
, transform=transform, target_transform=None, download=True)

Step 2,把Dataset封裝成torch.utils.data.DataLoader

data_loader = DataLoader(dataset=train_set,
                         batch_size=1,
                         shuffle=False,
                         num_workers=2)


# # 生成torch.utils.data.DataLoaderIter
# # 不過DataLoaderIter它會被DataLoader自動建立並且呼叫,我們用不到 # data_iter = iter(data_loader) # images, labels = next(data_iter)

step 3,從DataLoader裡讀取資料,並將圖片顯示出來。

注意:
1)使用for...in...迴圈讀取資料的時候,會自動呼叫DataLoader裡的__next__()函式
而且只能對Tensor例項進行迭代,所以之前的transforms必須最後加一個transforms.ToTensor()
2)顯示圖片有兩種方式:Image.show()plt.imshow(ndarray)
Image.show():
通過transforms.ToPILImage()FloatTensor轉化為Image
plt.imshow(ndarray)
通過FloatTensor.numpy()轉化為ndarray,再呼叫plt.imshow()

to_pil_image = transforms.ToPILImage()
cnt = 0
for image,label in data_loader:
    if cnt>=3:      # 只顯示3張圖片
        break
    print(label)    # 顯示label

    # 方法1:Image.show()
    # transforms.ToPILImage()中有一句
    # npimg = np.transpose(pic.numpy(), (1, 2, 0))
    # 因此pic只能是3-D Tensor,所以要用image[0]消去batch那一維
    img = to_pil_image(image[0])
    img.show()

    # 方法2:plt.imshow(ndarray)
    img = image[0]      # plt.imshow()只能接受3-D Tensor,所以也要用image[0]消去batch那一維
    img = img.numpy()   # FloatTensor轉為ndarray
    img = np.transpose(img, (1,2,0))    # 把channel那一維放到最後

    # 顯示圖片
    plt.imshow(img)
    plt.show()

    cnt += 1

另外補一句np.transpose()的用法。
第一個引數是要transpose的圖片;
第二個是shape。比如一個ndarray是(channel, height, width),如果給第二個引數(height, width,channel),就會把第0維channel整個搬到最後。