pytorch使用(一)處理並載入自己的資料
pytorch使用(一)資料處理
個人認為,資料處理或許是在完成一篇論文中最耗費時間的,特別是大多情況下,需要在很多個庫上做實驗。
pytorch官方支援很多庫,使用torchvision來完成資料的處理,點這裡可以看到支援的庫並不是很多。在這裡,我將結合一個例項說明如何使用pytorch來處理自己的資料,任務是一個分析雙臂運動的,檢測6個關節點的運動。輸入是連續三幀的檢測結果以及計算的光流,也就是$3*6+2*2=22$
張heatmap,輸出是中間幀的檢測結果,也就是6張heatmap。
把原始資料處理為模型使用的資料需要3步:transforms.Compose() torchvision.datasets torch.utils.data.DataLoader()分別可以理解為資料處理格式的定義、資料處理和資料載入。
1. 資料預處理torchvision.transforms
pytorch使用torchvision.transforms實現資料的預處理,包括中心化(torchvision.transforms.CenterCrop)、隨機剪下(torchvision.transforms.RandomCrop)、正則化、圖片變為Tensor、tensor變為圖片等,建議整體瀏覽一下這一部分的官方手冊,非常有用,資料處理很方便。
先轉換為張量,然後正則化:
import torchvision.transforms as transforms
transform = transforms.Compose ([transforms.ToTensor(),
transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))])
#img = transform(img)
2. 資料讀取,構建Dataset子類
如果想要使用自己的資料,則必須自己構建一個torch.utils.data.Dataset的子類去讀取資料。我們的將資料列表放在train.txt
和test.txt
中,將不同型別的資料的路徑放在path.txt
中,所以在類的init函式中有path_file和 list_file兩個變數
在定義torch.utils.data.Dataset的子類時,必須過載的兩個函式是len和getitem:
- len返回資料集的大小
- getitem實現資料集的下標索引,返回對應的影象和標記(不一定非得返回影象和標記,返回元組的長度可以是任意長,這由網路需要的資料決定)。
末尾有自己寫的一個Dataset子類的定義檔案。
3. 資料載入
torch.utils.data.DataLoader()
函式,合成數據並且提供迭代訪問。主要由兩部分組成:
- dataset(Dataset)。輸入載入的資料,就是上面的MyDataset
的實現。
- batch_size, shuffle, sampler, batch_sampler, num_worker, collate_fn, pin_memory, drop_last, timeout等引數,介紹幾個比較常用的,這些在官方網站都有:
- batch-size。樣本每個batch的大小,預設為1。
- shuffle。是否打亂資料,預設為False。
- num_workers。資料分為幾個執行緒處理預設為0。
- sampler。定義一個方法來繪製樣本資料,如果定義該方法,則不能使用shuffle。預設為False
使用:
import torch
from datagen import MyDataset
trainset = MyDataset(path_file=pathFile,list_file=trainList,numJoints = 6,type=False)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=8)
testset = MyDataset(path_file=pathFile,list_file=testList,numJoints = 6,type=False)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False, num_workers=8)
以下是定義class MyDataset
檔案datagen.py
, 其中有__init__(self, path_file, list_file,numJoints,type)
、__getitem__(self, idx)
、__len__(self)
三個函式,__getitem__
返回一個(22,256,256)的輸入和一個(6,256,256)的標籤。
'''
Load data
'''
import numpy as np
from PIL import Image
#import cv2
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
class MyDataset(data.Dataset):
def __init__(self, path_file, list_file,numJoints,type):
'''
Args:
path_file: (str) heatmap and optical file location
list_file: (str) path to index file.
numJoints: (int) number of joints
type: (boolean) use pose flow(true) or optical flow(false)
'''
self.numJoints = numJoints
# read heatmap and optical path
with open(path_file) as f:
paths = f.readlines()
for path in paths:
splited = path.strip().split()
if splited[0]=='resPath':
self.resPath = splited[1]
elif splited[0]=='gtPath':
self.gtPath = splited[1]
elif splited[0]=='opticalFlowPath':
self.opticalFlowPath = splited[1]
elif splited[0]=='poseFlowPath':
self.poseFlowPath = splited[1]
if type:
self.flowPath = self.poseFlowPath
else:
self.flowPath = self.opticalFlowPath
#read list
with open(list_file) as f:
self.list = f.readlines()
self.num_samples = len(self.list)
def __getitem__(self, idx):
'''
load heatmaps and optical flow and encode it to a 22 channels input and 6 channels output
:param idx: (int) image index
:return:
input: a 22 channel input which integrate 2 optical flow and heatmaps of 3 image
output: the ground truth
'''
input = []
output = []
# load heatmaps of 3 image
for im in range(3):
for map in range(6):
curResPath = self.resPath + self.list[idx].rstrip('\n') + str(im + 1) + '/' + str(map + 1) + '.bmp'
heatmap = Image.open(curResPath)
heatmap.load()
heatmap = np.asarray(heatmap, dtype='float') / 255
input.append(heatmap)
# load 2 flow
for flow in range(2):
curFlowXPath = self.flowPath + self.list[idx].rstrip('\n') + 'flowx/' + str(flow + 1) + '.jpg'
flowX = Image.open(curFlowXPath)
flowX.load()
flowX = np.asarray(flowX, dtype='float')
curFlowYPath = self.flowPath + self.list[idx].rstrip('\n') + 'flowy/' + str(flow + 1) + '.jpg'
flowY = Image.open(curFlowYPath)
flowY.load()
flowY = np.asarray(flowY, dtype='float')
input.append(flowX)
input.append(flowY)
# load groundtruth
for map in range(6):
curgtPath = self.resPath + self.list[idx].rstrip('\n') + str(2) + '/' + str(map + 1) + '.bmp'
heatmap = Image.open(curResPath)
heatmap.load()
heatmap = np.asarray(heatmap, dtype='float') / 255
output.append(heatmap)
input = torch.Tensor(input)
output = torch.Tensor(output)
return input,output
def __len__(self):
return self.num_samples