# PyTorch 進行 Dataset 資料載入
# 阿新 • 發佈:2020-12-23
from torch.utils.data import Dataset,DataLoader
from torchvision import datasets,transforms
import os
from PIL import Image
import numpy as np
import torch
class My_dataset(Dataset):
    """Image / multi-label dataset backed by a comma-separated index file.

    Each non-blank line of the index file has the form::

        <image_path>,<label_1>,<label_2>,...

    ``__getitem__`` returns ``(image_tensor, labels)``, where ``labels`` is a
    list of ``int`` parsed from the label columns of that line.
    """

    # Standard ImageNet per-channel (RGB) mean / std, as documented for
    # torchvision models. Override on a subclass to use other statistics.
    MEAN = (0.485, 0.456, 0.406)
    STD = (0.229, 0.224, 0.225)

    def __init__(self, root_path: str, is_train: bool = True,
                 is_miniTrain: bool = False, mini_size: int = 700):
        """Read the index file at *root_path*.

        Args:
            root_path: path to the UTF-8 index file (a leading BOM is tolerated).
            is_train: if True, ``__getitem__`` applies ``train_transform``
                (random augmentation); otherwise ``others_transform``.
            is_miniTrain: if True, keep only the first *mini_size* samples.
            mini_size: number of samples kept when *is_miniTrain* is set.
        """
        super().__init__()
        self.is_train = is_train
        self.x = []  # image paths, exactly as written in the index file
        self.y = []  # per-sample label columns, kept as strings
        # Transform pipelines are built on first use and then reused,
        # instead of being rebuilt for every sample.
        self._train_pipeline = None
        self._others_pipeline = None

        # 'utf-8-sig' decodes plain UTF-8 identically but also strips a
        # leading BOM, which would otherwise end up in the first image path.
        with open(root_path, 'r', encoding='utf-8-sig') as f:
            for line in f:
                line = line.rstrip()
                if not line:
                    # A blank line would otherwise become a sample with an
                    # empty path that only fails later, at load time.
                    continue
                path, *labels = line.split(',')
                self.x.append(path)
                self.y.append(labels)

        if is_miniTrain:
            self.x = self.x[:mini_size]
            self.y = self.y[:mini_size]

    def __len__(self):
        """Return the number of samples."""
        return len(self.x)

    def __getitem__(self, index):
        """Return ``(image, labels)`` for sample *index*.

        The image at ``self.x[index]`` is converted to RGB and passed through
        ``train_transform`` when ``is_train`` is set, otherwise through
        ``others_transform``. ``labels`` is ``self.y[index]`` parsed as ints.
        """
        path = self.x[index]
        # Close the file deterministically. convert('RGB') loads the pixel
        # data before the handle is closed, and it guarantees 3 channels to
        # match the 3-element mean/std in Normalize. Palette, grayscale or
        # RGBA files would otherwise fail inside Normalize.
        with Image.open(path) as src:
            img = src.convert('RGB')
        transform = self.train_transform if self.is_train else self.others_transform
        label = [int(v) for v in self.y[index]]
        return transform(img), label

    def train_transform(self, x):
        """Augment and normalize PIL image *x* for training.

        Steps: pad by 28 px, then take a random 224x224 crop; rotate by a
        random angle in (-0.5, 0.5) degrees; convert to tensor; normalize
        with ``MEAN`` / ``STD``.
        """
        if self._train_pipeline is None:
            self._train_pipeline = transforms.Compose([
                transforms.RandomCrop(224, padding=28),
                # A scalar ``degrees`` means a range of (-0.5, +0.5) degrees.
                transforms.RandomRotation(0.5),
                transforms.ToTensor(),
                transforms.Normalize(mean=self.MEAN, std=self.STD),
            ])
        return self._train_pipeline(x)

    def others_transform(self, x):
        """Convert PIL image *x* to a normalized tensor (no augmentation)."""
        if self._others_pipeline is None:
            self._others_pipeline = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=self.MEAN, std=self.STD),
            ])
        return self._others_pipeline(x)