重新定義PyTorch中的TensorDataset，可實現transforms
阿新 • 發佈：2019-03-02
Tags: python, pytorch
class TensorsDataset(torch.utils.data.Dataset):
    """Dataset wrapping a data tensor and an optional target tensor.

    Behaves like ``torch.utils.data.TensorDataset`` but additionally applies
    a chain of transforms to each data sample and/or target when it is
    fetched. If ``target_tensor`` is ``None``, ``__getitem__`` returns only
    the (transformed) data sample instead of a ``(data, target)`` pair.

    Args:
        data_tensor: Indexable data; ``len(data_tensor)`` is the number of
            samples and ``data_tensor[i]`` is the i-th sample.
        target_tensor: Optional indexable targets with the same length as
            ``data_tensor``.
        transforms: ``None``, a single callable, or a list/tuple of
            callables applied in order to every data sample.
        target_transforms: ``None``, a single callable, or a list/tuple of
            callables applied in order to every target.

    Raises:
        ValueError: If ``target_tensor`` is given and its length differs
            from that of ``data_tensor``.
    """

    def __init__(self, data_tensor, target_tensor=None, transforms=None,
                 target_transforms=None):
        # Explicit check rather than ``assert``: assertions are stripped when
        # Python runs with ``-O``, which would let mismatched data through.
        if target_tensor is not None and len(data_tensor) != len(target_tensor):
            raise ValueError(
                "data_tensor and target_tensor must have the same length, "
                "got {} and {}".format(len(data_tensor), len(target_tensor))
            )
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor
        self.transforms = self._as_transform_list(transforms)
        self.target_transforms = self._as_transform_list(target_transforms)

    @staticmethod
    def _as_transform_list(transforms):
        """Normalize None / a single callable / a list or tuple to a new list."""
        if transforms is None:
            return []
        if isinstance(transforms, (list, tuple)):
            # Copy so later mutation of the caller's sequence does not
            # silently change this dataset's pipeline.
            return list(transforms)
        return [transforms]

    @staticmethod
    def _apply(transforms, sample):
        """Feed ``sample`` through each callable in ``transforms`` in order."""
        for transform in transforms:
            sample = transform(sample)
        return sample

    def __getitem__(self, index):
        data = self._apply(self.transforms, self.data_tensor[index])
        if self.target_tensor is None:
            return data
        target = self._apply(self.target_transforms, self.target_tensor[index])
        return data, target

    def __len__(self):
        return len(self.data_tensor)
重新定義PyTorch中的TensorDataset，可實現transforms