1. 程式人生 > >【pytorch】訓練集的讀取

【pytorch】訓練集的讀取

pytorch讀取訓練集是非常便捷的,只需要使用到2個類:

(1)torch.utils.data.Dataset

(2)torch.utils.data.DataLoader

常用資料集的讀取

1、torchvision.datasets的使用

對於常用資料集,可以使用torchvision.datasets直接進行讀取。torchvision.dataset是torch.utils.data.Dataset的實現

該包提供了以下資料集的讀取

  • MNIST
  • COCO (Captioning and Detection)
  • LSUN Classification
  • ImageFolder
  • Imagenet-12
  • CIFAR10 and CIFAR100
  • STL10

下面以cifar10為例:

import torch
import torchvision
from PIL import Image

cifarSet = torchvision.datasets.CIFAR10(root = "../data/cifar/", train= True, download = True)
print(cifarSet[0])
img, label = cifarSet[0]
print (img)
print (label)
print (img.format, img.size, img.mode)
img.show()

2、例項化torch.utils.data.DataLoader

mytransform = transforms.Compose([
    transforms.ToTensor()
    ]
)

# torch.utils.data.DataLoader
cifarSet = torchvision.datasets.CIFAR10(root = "../data/cifar/", train= True, download = True, transform = mytransform )
cifarLoader = torch.utils.data.DataLoader(cifarSet, batch_size= 10, shuffle= False, num_workers= 2)

下面就可以進行讀取資料的顯示,以進行簡單測試是否讀取成功:

for i, data in enumerate(cifarLoader, 0):
    print(data[i][0])
    # PIL
    img = transforms.ToPILImage()(data[i][0])
    img.show()
    break

自定義標籤資料集的讀取

1、實現torch.utils.data.Dataset

假設我們有一個標籤test_images.txt,內容如下:


對應的影象位於images目錄下。

首先要繼承torch.utils.data.Dataset類,完成影象及標籤的讀取。

import os
import torch
import torch.utils.data as data
from PIL import Image

def default_loader(path):
    return Image.open(path).convert('RGB')

class myImageFloder(data.Dataset):
    def __init__(self, root, label, transform = None, target_transform=None, loader=default_loader):
        fh = open(label)
        c=0
        imgs=[]
        class_names=[]
        for line in  fh.readlines():
            if c==0:
                class_names=[n.strip() for n in line.rstrip().split('	')]
            else:
                cls = line.split() 
                fn = cls.pop(0)
                if os.path.isfile(os.path.join(root, fn)):
                    imgs.append((fn, tuple([float(v) for v in cls])))
            c=c+1
        self.root = root
        self.imgs = imgs
        self.classes = class_names
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = self.loader(os.path.join(self.root, fn))
        if self.transform is not None:
            img = self.transform(img)
        return img, torch.Tensor(label)

    def __len__(self):
        return len(self.imgs)
    
    def getName(self):
        return self.classes

2、例項化torch.utils.data.DataLoader

mytransform = transforms.Compose([
    transforms.ToTensor()
    ]
)

# torch.utils.data.DataLoader
imgLoader = torch.utils.data.DataLoader(
         myFloder.myImageFloder(root = "../data/testImages/images", label = "../data/testImages/test_images.txt", transform = mytransform ), 
         batch_size= 2, shuffle= False, num_workers= 2)

for i, data in enumerate(imgLoader, 0):
    print(data[i][0])
    # opencv
    img2 = data[i][0].numpy()*255
    img2 = img2.astype('uint8')
    img2 = np.transpose(img2, (1,2,0))
    img2=img2[:,:,::-1]#RGB->BGR
    cv2.imshow('img2', img2)
    cv2.waitKey()
    break


---------------------------------------------------------------------------------------------------

在各方小夥伴的努力和支援下,pytorch中文文件 第一版終於上線啦!!!(鼓掌)文件還有很多小瑕疵,但是大體可以放心使用了~我們遵循快速迭代的原則,所以趕緊上線第一版來接受廣大開源社群的意見和建議。歡迎加入我們!

中文翻譯組QQ群:628478868

還有pytorch專案交流群:613523596

歡迎關注!