1. 程式人生 > 其它 >【Pytorch】多個數據集聯合讀取

【Pytorch】多個數據集聯合讀取

技術標籤:模型訓練和部署pythonpytorch深度學習

深度學習好比煉丹,框架就是丹爐,網路結構及演算法就是單方,而資料集則是原材料。現在世面上很多煉丹手冊都是針對單一資料集進行煉丹,有了這些手冊我們就能夠很容易進行煉丹,但為了練好丹,我們常常收集各種公開的資料集,並構建私有資料集,此時,便會遇到如何更好的使用多個數據進行練丹的問題。

本文將使用pytorch這個丹爐,介紹如何聯合讀取多個原材料,而不是從新制作原材料和標籤。

1、Pytorch的ConcatDataset介紹

class ConcatDataset(Dataset):
    """
    Dataset to concatenate multiple datasets.
    Purpose: useful to assemble different existing datasets, possibly
    large-scale datasets as the concatenation operation is done in an
    on-the-fly manner.

    Arguments:
        datasets (sequence): List of datasets to be concatenated
    """

    @staticmethod
    def cumsum(sequence):
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def __init__(self, datasets):
        super(ConcatDataset, self).__init__()
        assert len(datasets) > 0, 'datasets should not be an empty iterable'
        self.datasets = list(datasets)
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    def cummulative_sizes(self):
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes

首先,ConcatDataset繼承自Dataset類。

其次,ConcatDataset的建構函式要求一個列表L作為輸入,其包含若干個資料集的。建構函式會計算出一個cumulative size列表,裡面存放了”把L中的第n個數據集加上後一共有多少個樣本“的序列。

然後,ConcatDataset重寫了__len__方法,返回cumulative_size[-1],也就是若干個資料集的總樣本數;

最後,重寫了__getitem__,當給定索引 idx 的時候,會計算出該idx對應那個資料集及在這個資料集中的位置,這樣就可以訪問這個資料了。


2、多個數據集聯合讀取示例

假設我們需要讀取MNIST、CIFAR10和CIFAR100三個資料集。

首先,這三個資料集在torchvision中已經實現,呼叫方式如下:

mnist_data = MNIST('./data', train=True, download=True)
cifar10_data = CIFAR10('./data', train=True, download=True)
cifar100_data = CIFAR100('./data', train=True, download=True)

如果是其他資料集也要先實現讀取;

其次,定義一個數據種類和其訪問介面的字典:

_DATASETS = {
    'MNIST': MNIST,
    'CIFAR10': CIFAR10,
    'CIFAR100': CIFAR100,
}

然後,定義一個數據信息類,存放資料地址等資訊:

class DatasetCatalog:
    DATASETS = {
        'MNIST': {
            "root": "./data",
        },
        'CIFAR10': {
            "root": "./data",
        },
        'CIFAR100': {
            "root": "./data",
        }
    }

    @staticmethod
    def get(name):
        if "MNIST" in name:
            attrs = DatasetCatalog.DATASETS[name]
            args = dict(
                root=attrs["root"],
            )
            return dict(factory="MNIST", args=args)
        elif "CIFAR10" in name:
            attrs = DatasetCatalog.DATASETS[name]
            args = dict(
                root=attrs["root"],
            )
            return dict(factory="CIFAR10", args=args)
        elif "CIFAR100" in name:
            attrs = DatasetCatalog.DATASETS[name]
            args = dict(
                root=attrs["root"],
            )
            return dict(factory="CIFAR100", args=args)

        raise RuntimeError("Dataset not available: {}".format(name))

最後,定義一個製作資料集的函式,通過dataset_list指定需要載入的資料集名稱,對於train模式,會返回合併後的資料集,對於val模式,返回各自的val資料集

def make_dataset(dataset_list, train=True, transform=None, target_transform=None, download=True):
    assert len(dataset_list) > 0
    data_sets = []
    for dataset_name in dataset_list:
        catalog = DatasetCatalog.get(dataset_name)
        args = catalog['args']
        factory = _DATASETS[catalog['factory']]
        args['train'] = train
        args['transform'] = transform
        args['target_transform'] = target_transform
        args['download'] = download
        if factory == MNIST:
            data_set = factory(**args)
        elif factory == CIFAR10:
            data_set = factory(**args)
        elif factory == CIFAR100:
            data_set = factory(**args)
        data_sets.append(data_set)

    if not train:
        return data_sets
    data_set = data_sets[0]
    if len(data_sets) > 1:
        data_set = ConcatDataset(data_sets)
    return data_set

具體看一下如何呼叫吧:

if __name__ == "__main__":
    dataset_list = ["MNIST", "CIFAR10", "CIFAR100"]
    concat_data = make_dataset(dataset_list, train=True, download=True)

    for i, (data, target) in enumerate(concat_data):
        print(np.array(data).shape)
        print(target)

獲取了concat_data後,就可以通過dataloader來定義loader了。

對於其他資料集或私有資料集,可以改一改,能夠實現任何想要的輸出。

好了,本文就到這裡了。

老話一句,關注公眾號:AI約讀社。後臺回覆:concat,獲取完整程式碼。

也歡迎加微信群交流討論。