1. 程式人生 > 程式設計 >PyTorch實現重寫/改寫Dataset並載入Dataloader

PyTorch實現重寫/改寫Dataset並載入Dataloader

前言

眾所周知,Dataset和Dataloder是pytorch中進行資料載入的部件。必須將資料載入後,再進行深度學習模型的訓練。在pytorch的一些案例教學中,常使用torchvision.datasets自帶的MNIST、CIFAR-10資料集,一般流程為:

# 下載並存放資料集
train_dataset = torchvision.datasets.CIFAR10(root="資料集存放位置",download=True)
# load資料
train_loader = torch.utils.data.DataLoader(dataset=train_dataset)

但是,在我們自己的模型訓練中,需要使用非官方自制的資料集。這時應該怎麼辦呢?

我們可以通過改寫torch.utils.data.Dataset中的__getitem____len__來載入我們自己的資料集。
__getitem__獲取資料集中的資料,__len__獲取整個資料集的長度(即個數)。

改寫

採用pytorch官網案例中提供的一個臉部landmark資料集。資料集中含有存放landmark的csv檔案,但是我們在這篇文章中不使用(其實也可以隨便下載一些圖片作資料集來實驗)。

import os
import torch
from skimage import io,transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms,utils

plt.ion()  # interactive mode

torch.utils.data.Dataset是一個抽象類,我們自己的資料集需要繼承Dataset,然後改寫上述兩個函式:

class ImageLoader(Dataset):
  def __init__(self,file_path,transform=None):
    super(ImageLoader,self).__init__()
    self.file_path = file_path
    self.transform = transform # 對輸入影象進行預處理,這裡並沒有做,預設為None
    self.image_names = os.listdir(self.file_path) # 檔名的列表
    
  def __getitem__(self,idx):
    image = self.image_names[idx]
    image = io.imread(os.path.join(self.file_path,image))
#    if self.transform:
#    	image= self.transform(image)
    return image
         
  def __len__(self):
    return len(self.image_names)

# 設定自己存放的資料集位置,並plot展示    
imageloader = ImageLoader(file_path="D:\\Projects\\datasets\\faces\\")
# imageloader.__len__()       # 輸出資料集長度(個數),應為71
# print(imageloader.__getitem__(0)) # 以資料形式展示
plt.imshow(imageloader.__getitem__(0)) # 以影象形式展示
plt.show()

得到的圖片輸出:

PyTorch實現重寫/改寫Dataset並載入Dataloader

得到的資料輸出,:

array([[[ 66,59,53],[ 66,...,[ 59,54,48],48]],[153,141,129],[158,146,134],134]]],dtype=uint8)

上面看到dytpe=uint8,實際進行訓練的時候,常常需要更改成float的資料型別。可以使用:

# 直接改成pytorch中的tensor下的float格式 
# 也可以用numpy的改成普通的float格式
to_float= torch.from_numpy(imageloader.__getitem__(0)).float() 

改寫完成後,直接使用train_loader =torch.utils.data.DataLoader(dataset=imageloader)載入到Dataloader中,就可以使用了。
下面的程式碼可以試著執行一下,產生的是一模一樣的圖片結果。

train_loader = torch.utils.data.DataLoader(dataset=imageloader)
train_loader.dataset[0]
plt.imshow(train_loader.dataset[0])
plt.show()

到此這篇關於PyTorch實現重寫/改寫Dataset並載入Dataloader的文章就介紹到這了,更多相關PyTorch重寫/改寫Dataset 內容請搜尋我們以前的文章或繼續瀏覽下面的相關文章希望大家以後多多支援我們!