1. 程式人生 > 其它 > pytorch進行Dataset資料載入

pytorch進行Dataset資料載入

技術標籤：資料處理、深度學習、pytorch

from torch.utils.data import Dataset,DataLoader
from torchvision import datasets,transforms
import os
from PIL import Image
import numpy as np
import torch

class My_dataset(Dataset):
    """Image/label dataset driven by a comma-separated index file.

    Each non-blank line of the index file has the form
    ``<image_path>,<label>,<label>,...``. The first field is an image path
    opened with ``PIL.Image.open``; the remaining fields are parsed as ints.

    Args:
        root_path: Path to the UTF-8 index file (a file, not a directory,
            despite the name).
        is_train: If True, images go through ``train_transform`` (random
            crop + rotation augmentation); otherwise through
            ``others_transform``.
        is_miniTrain: If True, keep only the first ``mini_size`` samples
            (handy for quick debugging runs).
        mini_size: Number of samples kept when ``is_miniTrain`` is True.

    ``__getitem__`` returns ``(image_tensor, labels)`` where ``labels`` is a
    list of ints.
    """

    # ImageNet per-channel statistics used by Normalize.
    # NOTE: earlier versions used mean=[0.485, 0.456, 0.486] and std equal to
    # the mean (a copy-paste error). Models trained with those old values
    # will see differently normalized inputs with these constants.
    MEAN = (0.485, 0.456, 0.406)
    STD = (0.229, 0.224, 0.225)

    def __init__(self, root_path, is_train=True, is_miniTrain=False, mini_size=700):
        super().__init__()
        self.is_train = is_train
        self.x = []  # image paths, one per sample
        self.y = []  # label fields per sample, still as strings
        # `with` guarantees the index file is closed after reading.
        with open(root_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.rstrip()
                if not line:
                    # Skip blank lines (e.g. a trailing newline at EOF);
                    # they would otherwise become samples with path ''.
                    continue
                path, *labels = line.split(',')
                self.x.append(path)
                self.y.append(labels)
        if is_miniTrain:
            self.x = self.x[:mini_size]
            self.y = self.y[:mini_size]

    def __len__(self):
        return len(self.x)

    def __getitem__(self, index):
        path = self.x[index]  # e.g. './img/my_data/TRAIN/2586_paste.png'
        # Decode inside `with` so the file handle is closed. Converting to RGB
        # makes RGBA / palette / grayscale images match the 3-channel Normalize.
        with Image.open(path) as im:
            img = im.convert('RGB')
        if self.is_train:
            img = self.train_transform(img)
        else:
            img = self.others_transform(img)
        label = self._parse_labels(self.y[index])
        return img, label

    @staticmethod
    def _parse_labels(fields):
        """Convert a sequence of label strings to a list of ints."""
        return [int(v) for v in fields]

    def _normalize(self):
        """Build the Normalize transform shared by both pipelines."""
        return transforms.Normalize(mean=list(self.MEAN), std=list(self.STD))

    def train_transform(self, x):
        """Augment and normalize a PIL image for training; returns a tensor."""
        return transforms.Compose([
            # Pads 28 px on every side, then takes a random 224x224 crop.
            transforms.RandomCrop(224, padding=28),
            # A scalar degrees value samples the angle from [-0.5, 0.5] degrees.
            # NOTE(review): this is a very small range — confirm it is intended.
            transforms.RandomRotation(0.5),
            transforms.ToTensor(),
            self._normalize(),
        ])(x)

    def others_transform(self, x):
        """Normalize a PIL image without augmentation; returns a tensor."""
        return transforms.Compose([
            transforms.ToTensor(),
            self._normalize(),
        ])(x)