1. 程式人生 > 實用技巧 >用torchvision.datasets.ImageFolder載入圖片資料集

用torchvision.datasets.ImageFolder載入圖片資料集

一、專案結構

二、程式碼

 1 data_loader = torch.utils.data.DataLoader(
 2     torchvision.datasets.ImageFolder('traing_dataset',
 3        transform=torchvision.transforms.Compose([
 4                 torchvision.transforms.Resize([28, 28]),       # 裁剪圖片
 5                 torchvision.transforms.Grayscale(1),           #
單通道 6 torchvision.transforms.ToTensor(), # 將圖片資料轉成tensor格式 7 torchvision.transforms.Normalize( # 歸一化 8 (0.1307,), (0.3081,)) 9 ])), 10 batch_size=10, shuffle=False)                  # 10張圖片

三、顯示效果

 1 def
plot_image(img, label, name): 2 fig = plt.figure() 3 for i in range(6):                                    # 只顯示6張 4 plt.subplot(2, 3, i+1)                               # 2行3列第i+1張 5 plt.tight_layout() 6 plt.imshow(img[i][0]*0.3081+0.1307, cmap='Greys', interpolation='
none') 7 plt.title("{}:{}".format(name, label[i].item()))                # 標題名稱 8 plt.xticks([]) 9 plt.yticks([]) 10 plt.show() 11 12 x, y = next(iter(data_loader))                                # 資料夾的名稱即為圖片的label 13 print(x.shape, y.shape) 14 plot_image(x, y, 'image')