【深度學習】PyTorch Dataset類的使用與例項分析
Dataset類
介紹
當我們得到一個數據集時,Dataset類可以幫我們提取我們需要的資料,我們用子類繼承Dataset類,我們先給每個資料一個編號(idx),在後面的神經網路中,初始化Dataset子類例項後,就可以通過這個編號去例項物件中讀取相應的資料,會自動呼叫__getitem__方法,同時子類物件也會獲取相應真實的Label(人為去複寫即可)
Dataset類的作用:提供一種方式去獲取資料及其對應的真實Label
在Dataset類的子類中,應該有以下函式以實現某些功能:
- 獲取每一個數據及其對應的Label
- 統計資料集中的資料數量
關於2,神經網路經常需要對一個數據迭代多次,只有知道當前有多少個數據,進行訓練時才知道要訓練多少次,才能把整個資料集迭代完
Dataset官方文件解讀
首先看一下Dataset的官方文件解釋
匯入Dataset類:
from torch.utils.data import Dataset
我們可以通過在Jupyter中檢視官方文件
from torch.utils.data import Dataset
help(Dataset)
輸出:
Help on class Dataset in module torch.utils.data.dataset: class Dataset(typing.Generic) | An abstract class representing a :class:`Dataset`. | | All datasets that represent a map from keys to data samples should subclass | it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a | data sample for a given key. Subclasses could also optionally overwrite | :meth:`__len__`, which is expected to return the size of the dataset by many | :class:`~torch.utils.data.Sampler` implementations and the default options | of :class:`~torch.utils.data.DataLoader`. | | .. note:: | :class:`~torch.utils.data.DataLoader` by default constructs a index | sampler that yields integral indices. To make it work with a map-style | dataset with non-integral indices/keys, a custom sampler must be provided. | | Method resolution order: | Dataset | typing.Generic | builtins.object | | Methods defined here: | | __add__(self, other:'Dataset[T_co]') -> 'ConcatDataset[T_co]' | | __getattr__(self, attribute_name) | | __getitem__(self, index) -> +T_co | | ---------------------------------------------------------------------- | Class methods defined here: | | register_datapipe_as_function(function_name, cls_to_register, enable_df_api_tracing=False) from typing.GenericMeta | | register_function(function_name, function) from typing.GenericMeta | | ---------------------------------------------------------------------- | Data descriptors defined here: | | __dict__ | dictionary for instance variables (if defined) | | __weakref__ | list of weak references to the object (if defined) | | ---------------------------------------------------------------------- | Data and other attributes defined here: | | __abstractmethods__ = frozenset() | | __annotations__ = {'functions': typing.Dict[str, typing.Callable]} | | __args__ = None | | __extra__ = None | | __next_in_mro__ = <class 'object'> | The most base type | | __orig_bases__ = (typing.Generic[+T_co],) | | __origin__ = None | | __parameters__ = (+T_co,) | | __tree_hash__ = -9223371872509358054 | | functions = {'concat': functools.partial(<function Dataset.register_da... | | ---------------------------------------------------------------------- | Static methods inherited from typing.Generic: | | __new__(cls, *args, **kwds) | Create and return a new object. See help(type) for accurate signature.
還有一種方式獲取官方文件資訊:
Dataset??
輸出:
Init signature: Dataset(*args, **kwds) Source: class Dataset(Generic[T_co]): r"""An abstract class representing a :class:`Dataset`. All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a data sample for a given key. Subclasses could also optionally overwrite :meth:`__len__`, which is expected to return the size of the dataset by many :class:`~torch.utils.data.Sampler` implementations and the default options of :class:`~torch.utils.data.DataLoader`. .. note:: :class:`~torch.utils.data.DataLoader` by default constructs a index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided. """ functions: Dict[str, Callable] = {} def __getitem__(self, index) -> T_co: raise NotImplementedError def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]': return ConcatDataset([self, other]) # No `def __len__(self)` default? # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ] # in pytorch/torch/utils/data/sampler.py def __getattr__(self, attribute_name): if attribute_name in Dataset.functions: function = functools.partial(Dataset.functions[attribute_name], self) return function else: raise AttributeError @classmethod def register_function(cls, function_name, function): cls.functions[function_name] = function @classmethod def register_datapipe_as_function(cls, function_name, cls_to_register, enable_df_api_tracing=False): if function_name in cls.functions: raise Exception("Unable to add DataPipe function name {} as it is already taken".format(function_name)) def class_function(cls, enable_df_api_tracing, source_dp, *args, **kwargs): result_pipe = cls(source_dp, *args, **kwargs) if isinstance(result_pipe, Dataset): if enable_df_api_tracing or isinstance(source_dp, DFIterDataPipe): if function_name not in UNTRACABLE_DATAFRAME_PIPES: result_pipe = result_pipe.trace_as_dataframe() return result_pipe function = functools.partial(class_function, cls_to_register, enable_df_api_tracing) cls.functions[function_name] = function File: d:\environment\anaconda3\envs\py-torch\lib\site-packages\torch\utils\data\dataset.py Type: GenericMeta Subclasses: Dataset, IterableDataset, Dataset, TensorDataset, ConcatDataset, Subset, Dataset, Subset, Dataset, IterableDataset[+T_co], ...
其中我們可以看到:
"""An abstract class representing a :class:`Dataset`.
All datasets that represent a map from keys to data samples should subclass
it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
data sample for a given key. Subclasses could also optionally overwrite
:meth:`__len__`, which is expected to return the size of the dataset by many
:class:`~torch.utils.data.Sampler` implementations and the default options
of :class:`~torch.utils.data.DataLoader`.
"""
以上內容顯示:
該類是一個抽象類,所有的資料集想要在資料與標籤之間建立對映,都需要繼承這個類,所有的子類都需要重寫__getitem__
方法,該方法根據索引值獲取每一個數據並且獲取其對應的Label,子類也可以重寫__len__
方法,返回資料集的size大小
例項:GetData類
準備工作
首先我們建立一個類,類名為GetData,這個類要繼承Dataset類
class GetData(Dataset):
一般在類中首先需要寫的是__init__
方法,此方法用於物件例項化,通常用來提供類中需要使用的變數,可以先不寫
class GetData(Dataset):
def __init__(self):
pass
我們可以先寫__getitem__
方法:
class GetData(Dataset):
def __init__(self):
pass
def __getitem__(self, idx): # 預設是item,但常改為idx,是index的縮寫
pass
其中,idx是index的簡稱,就是一個編號,以便以後資料集獲取後,我們使用索引編號訪問每個資料
在實現GetData類之前,我們首先需要解決的問題就是如何讀取一個影象資料,通常我們使用PIL來讀取
PIL獲取影象資料
我們使用PIL來讀取資料,它提供一個Image模組,可以讓我們提取影象資料,我們先匯入這個模組
from PIL import Image
我們可以在Python Console中看看如何使用 Image
在Python Console中,輸入程式碼:
from PIL import Image
將資料集放入專案資料夾,我們需要獲取圖片的絕對路徑,選中具體的圖片,右鍵選擇Copy Path,然後選擇 Absolute path(快捷鍵:Ctrl + Shift + C)
img_path = "D:\\DeepLearning\\dataset\\train\\ants\\0013035.jpg"
在Windows下,路徑分割需要是
\\
,來表示轉譯也可以在字串前面加
r
防轉譯
使用Image的open方法讀取圖片:
img = Image.open(img_path)
可以在Python控制檯看到讀取出來的 img,是一個JpegImageFile類的物件
在圖中,可以看到這個物件的一些屬性,比如size
我們檢視這個屬性的內容,輸入以下程式碼:
img.size
輸出:
(768, 512)
我們可以看到此圖的寬是768,高是512,__len__
表示的是這個size元組的長度,有兩個值,所以為 2
show方法顯示圖片:
img.show()
獲取圖片的檔名
從資料集路徑中,獲取所有檔案的名字,儲存到一個列表中
一個簡單的例子(在Python Console中):
我們需要藉助os模組
import os
dir_path = "dataset/train/ants_image"
img_path_list = os.listdir(dir_path)
listdir方法會將路徑下的所有檔名(包括字尾名)組成一個列表
我們可以使用索引去訪問列表中的每個檔名
img_path_list[0]
Out[14]: '0013035.jpg'
構建資料集路徑
我們需要搭建資料集的路徑表示,一個根目錄路徑和一個具體的子目錄路徑,以作為不同資料集的區分
一個簡單的案例,在Python Console中輸入:
root_dir = "dataset/train"
child_dir = "ants_image"
我們使用os.path.join
方法,將兩個路徑拼接起來,就得到了ants子資料集的相對路徑
path = os.path.join(root_dir, child_dir)
path的值此時是:
path={str}'dataset/train\\ants_image'
我們有了這個資料集的路徑後,就可以使用之前所講的listdir方法,獲取這個路徑中所有檔案的檔名,儲存到一個列表中
img_path_list = os.listdir(path)
idx = 0
img_path_list[idx]
Out[21]: '0013035.jpg'
可以看到結果與我們之前的小案例是一樣的
有了具體的名字,我們還可以將這個檔名與路徑進行組合,然後使用PIL獲取具體的影象img物件
img_name = img_path_list[idx]
img_item_path = os.path.join(root_dir, child_dir, img_name)
img = Image.open(img_item_path)
在掌握瞭如何組裝路徑、獲取路徑中的檔名以及獲取具體影象物件後,我們可以完善我們的__init__
與__getitem__
方法了
完善__init__方法
在init中為啥使用self:一個函式中的變數是不能拿到另外一個函式中使用的,self可以當做類中的全域性變數
class GetData(Dataset):
def __init__(self, root_dir, label_dir):
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(self.root_dir, self.label_dir)
self.img_path_list = os.listdir(self.path)
很簡單,就是接收例項化時傳入的引數:獲取根目錄路徑、子目錄路徑
然後將兩個路徑進行組合,就得到了目標資料集的路徑
我們將這個路徑作為引數傳入listdir函式,從而讓img_path_list中儲存該目錄下所有檔名(包含字尾名)
此時通過索引就可以輕鬆獲取每個檔名
接下來,我們要使用這些初始化的資訊去獲取其中的每一個圖片的JpegImageFile物件
完善__getitem__方法
我們在初始化中,已經通過組裝資料集路徑,進而通過listdir方法獲取了資料集中每個檔案的檔名,存入了一個列表中。
在__getitem__方法中,預設會有一個 item 引數,常命名為 idx,這個引數是一個索引編號,用於對我們初始化中得到的檔名列表進行索引訪問,我們就得到了具體的檔名,然後與根目錄、子目錄再次組裝,得到具體資料的相對路徑,我們可以通過這個路徑獲取到索引編號對應的資料物件本身。
這樣巧妙的讓索引與資料集中的具體資料對應了起來
def __getitem__(self, idx):
img_name = self.img_path_list[idx] # 從檔名列表中獲取了檔名
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name) # 組裝路徑,獲得了圖片具體的路徑
獲取了具體的影象路徑後,我們需要使用PIL讀取這個影象
def __getitem__(self, idx):
img_name = self.img_path[idx]
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
img = Image.open(img_item_path)
label = self.label_dir
return img, label
此處img是一個JpegImageFile物件,label是一個字串
自此,這個函式我們就實現完成了
以後使用這個類進行例項化時,傳入的引數是根目錄路徑,以及對應的label名,我們就可以得到一個GetData物件。
有了這個GetData物件後,我們可以直接使用索引來獲取具體的影象物件(類:JpegImageFile),因為__getitem__方法已經幫我們實現了,我們只需要使用索引即可呼叫__getitem__方法,會返回我們根據索引提取到的對應資料的影象物件以及其label
root_dir = "dataset/train"
ants_label_dir = "ants_image"
bees_label_dir = "bees_image"
ants_dataset = GetData(root_dir, ants_label_dir)
bees_dataset = GetData(root_dir, bees_label_dir)
img1, label1 = ants_dataset[0] # 返回一個元組,返回值是__getitem__方法的返回值
img2, label2 = bees_dataset[0]
完善__len__方法
__len__實現很簡單
主要功能是獲取資料集的長度,由於我們在初始化中已經獲取了所有檔名的列表,所以只需要知道這個列表的長度,就知道了有多少個檔案,也就是知道了有多少個具體的資料
def __len__(self):
return len(self.img_path_list)
組合資料集
我們還可以將兩個資料集物件進行組合,組合成一個大的資料集物件
train_dataset = ants_dataset + bees_dataset
我們看看這三個資料集物件的大小(在python Console中):
len1 = len(ants_dataset)
len2 = len(bees_dataset)
len3 = len(train_dataset)
輸出:
124
121
245
我們可以看到剛好 $$124 + 121 = 245$$
而對這個組合的資料集的訪問也很有意思,也同樣是使用索引,0 ~ 123 都是ants資料集的內容,124 - 244 都是bees資料集的內容
img1, label1 = train_dataset[123]
img1.show()
img2, label2 = train_dataset[124]
img2.show()
完整程式碼
from torch.utils.data import Dataset
from PIL import Image
import os
class GetData(Dataset):
# 初始化為整個class提供全域性變數,為後續方法提供一些量
def __init__(self, root_dir, label_dir):
# self
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(self.root_dir, self.label_dir)
self.img_path_list = os.listdir(self.path)
def __getitem__(self, idx):
img_name = self.img_path_list[idx] # 只獲取了檔名
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name) # 每個圖片的位置
# 讀取圖片
img = Image.open(img_item_path)
label = self.label_dir
return img, label
def __len__(self):
return len(self.img_path)
root_dir = "dataset/train"
ants_label_dir = "ants_image"
bees_label_dir = "bees_image"
ants_dataset = GetData(root_dir, ants_label_dir)
bees_dataset = GeyData(root_dir, bees_label_dir)
img, lable = ants_dataset[0] # 返回一個元組,返回值就是__getitem__的返回值
# 獲取整個訓練集,就是對兩個資料集進行了拼接
train_dataset = ants_dataset + bees_dataset
len1 = len(ants_dataset) # 124
len2 = len(bees_dataset) # 121
len = len(train_dataset) # 245
img1, label1 = train_dataset[123] # 獲取的是螞蟻的最後一個
img2, label2 = train_dataset[124] # 獲取的是蜜蜂第一個