【Pytorch】多個數據集聯合讀取
深度學習好比煉丹,框架就是丹爐,網路結構及演算法就是單方,而資料集則是原材料。現在世面上很多煉丹手冊都是針對單一資料集進行煉丹,有了這些手冊我們就能夠很容易進行煉丹,但為了練好丹,我們常常收集各種公開的資料集,並構建私有資料集,此時,便會遇到如何更好的使用多個數據進行練丹的問題。
本文將使用pytorch這個丹爐,介紹如何聯合讀取多個原材料,而不是從新制作原材料和標籤。
1、Pytorch的ConcatDataset介紹
class ConcatDataset(Dataset): """ Dataset to concatenate multiple datasets. Purpose: useful to assemble different existing datasets, possibly large-scale datasets as the concatenation operation is done in an on-the-fly manner. Arguments: datasets (sequence): List of datasets to be concatenated """ @staticmethod def cumsum(sequence): r, s = [], 0 for e in sequence: l = len(e) r.append(l + s) s += l return r def __init__(self, datasets): super(ConcatDataset, self).__init__() assert len(datasets) > 0, 'datasets should not be an empty iterable' self.datasets = list(datasets) self.cumulative_sizes = self.cumsum(self.datasets) def __len__(self): return self.cumulative_sizes[-1] def __getitem__(self, idx): if idx < 0: if -idx > len(self): raise ValueError("absolute value of index should not exceed dataset length") idx = len(self) + idx dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) if dataset_idx == 0: sample_idx = idx else: sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] return self.datasets[dataset_idx][sample_idx] @property def cummulative_sizes(self): warnings.warn("cummulative_sizes attribute is renamed to " "cumulative_sizes", DeprecationWarning, stacklevel=2) return self.cumulative_sizes
首先,ConcatDataset繼承自Dataset類。
其次,ConcatDataset的建構函式要求一個列表L作為輸入,其包含若干個資料集的。建構函式會計算出一個cumulative size列表,裡面存放了”把L中的第n個數據集加上後一共有多少個樣本“的序列。
然後,ConcatDataset重寫了__len__方法,返回cumulative_size[-1],也就是若干個資料集的總樣本數;
最後,重寫了__getitem__,當給定索引 idx 的時候,會計算出該idx對應那個資料集及在這個資料集中的位置,這樣就可以訪問這個資料了。
2、多個數據集聯合讀取示例
假設我們需要讀取MNIST、CIFAR10和CIFAR100三個資料集。
首先,這三個資料集在torchvision中已經實現,呼叫方式如下:
mnist_data = MNIST('./data', train=True, download=True)
cifar10_data = CIFAR10('./data', train=True, download=True)
cifar100_data = CIFAR100('./data', train=True, download=True)
如果是其他資料集也要先實現讀取;
其次,定義一個數據種類和其訪問介面的字典:
_DATASETS = { 'MNIST': MNIST, 'CIFAR10': CIFAR10, 'CIFAR100': CIFAR100, }
然後,定義一個數據信息類,存放資料地址等資訊:
class DatasetCatalog:
DATASETS = {
'MNIST': {
"root": "./data",
},
'CIFAR10': {
"root": "./data",
},
'CIFAR100': {
"root": "./data",
}
}
@staticmethod
def get(name):
if "MNIST" in name:
attrs = DatasetCatalog.DATASETS[name]
args = dict(
root=attrs["root"],
)
return dict(factory="MNIST", args=args)
elif "CIFAR10" in name:
attrs = DatasetCatalog.DATASETS[name]
args = dict(
root=attrs["root"],
)
return dict(factory="CIFAR10", args=args)
elif "CIFAR100" in name:
attrs = DatasetCatalog.DATASETS[name]
args = dict(
root=attrs["root"],
)
return dict(factory="CIFAR100", args=args)
raise RuntimeError("Dataset not available: {}".format(name))
最後,定義一個製作資料集的函式,通過dataset_list指定需要載入的資料集名稱,對於train模式,會返回合併後的資料集,對於val模式,返回各自的val資料集
def make_dataset(dataset_list, train=True, transform=None, target_transform=None, download=True):
assert len(dataset_list) > 0
data_sets = []
for dataset_name in dataset_list:
catalog = DatasetCatalog.get(dataset_name)
args = catalog['args']
factory = _DATASETS[catalog['factory']]
args['train'] = train
args['transform'] = transform
args['target_transform'] = target_transform
args['download'] = download
if factory == MNIST:
data_set = factory(**args)
elif factory == CIFAR10:
data_set = factory(**args)
elif factory == CIFAR100:
data_set = factory(**args)
data_sets.append(data_set)
if not train:
return data_sets
data_set = data_sets[0]
if len(data_sets) > 1:
data_set = ConcatDataset(data_sets)
return data_set
具體看一下如何呼叫吧:
if __name__ == "__main__":
dataset_list = ["MNIST", "CIFAR10", "CIFAR100"]
concat_data = make_dataset(dataset_list, train=True, download=True)
for i, (data, target) in enumerate(concat_data):
print(np.array(data).shape)
print(target)
獲取了concat_data後,就可以通過dataloader來定義loader了。
對於其他資料集或私有資料集,可以改一改,能夠實現任何想要的輸出。
好了,本文就到這裡了。
老話一句,關注公眾號:AI約讀社。後臺回覆:concat,獲取完整程式碼。
也歡迎加微信群交流討論。