行人重識別(8)——程式碼實踐之資料管理器(dataset_manager.py)
阿新 • • 發佈:2020-12-25
!轉載請註明原文地址!——東方旅行者
本文目錄
資料管理器(dataset_manager.py)
一、資料管理器作用
該檔案主要負責指定資料集路徑、處理原始資料集並生成資料索引列表、返回子資料集相關引數(子集行人ID數量,子集圖片數量)。因為Market1501已經劃分好訓練集、測試集與查詢集,所以直接可以根據路徑提取這三個資料集。
二、資料管理器編寫思路
- 指定資料集根目錄路徑
- 分別指定訓練集、測試集、查詢集的路徑
- 通過這三個子資料集路徑獲得子集下所有圖片的地址,通過每個圖片的地址就可以得到行人ID與攝像機ID等資訊,根據這些資訊生成一個索引列表,型別為list,列表中每個元素都是一個三元組(資料圖片地址,行人ID,攝像頭ID),與此同時獲取子資料集相關資訊,如子集行人ID數量,子集圖片數量等引數
- 三個資料集索引列表生成完畢,列印相關引數到控制檯
索引列表如下所示:
[ ('./data/Market-1501-v15.09.15\\bounding_box_train\\0002_c1s1_000451_03.jpg',0,0), ('./data/Market-1501-v15.09.15\\bounding_box_train\\0002_c1s1_000551_01.jpg',0,0), ('./data/Market-1501-v15.09.15\\bounding_box_train\\0002_c1s1_000776_01.jpg',0,0) ]
因為每一個子集中行人ID不一定連續,所以為了便於訓練,一般要對訓練集的行人ID進行重排,便於訓練。所以需要使用一個名稱為pid2label的Map來記錄原始ID與重排ID的對應關係。
三、程式碼
import os
import os.path as osp
import numpy as np
import glob
import re
from IPython import embed
"""
Market1501類用於
1.指定資料集路徑
2.處理原始資料集並生成資料索引列表
3.返回子資料集的相關引數(子集行人ID數量,子集圖片數量)
"""
class Market1501(object):
dataset_dir='data/Market-1501-v15.09.15'#指定資料集路徑
def __init__(self,root='./',**kwargs):
self.dataset_dir=osp.join(root,self.dataset_dir)
self.train_dir=osp.join(self.dataset_dir,'bounding_box_train')#訓練集
self.gallery_dir=osp.join(self.dataset_dir,'bounding_box_test')#測試集
self.query_dir=osp.join(self.dataset_dir,'query')#查詢集
train, num_train_pids, num_train_imgs=self._process_dir(self.train_dir,relabel=True)
query, num_query_pids, num_query_imgs=self._process_dir(self.query_dir,relabel=False)
gallery, num_gallery_pids, num_gallery_imgs=self._process_dir(self.gallery_dir,relabel=False)
num_total_pids=num_train_pids+num_query_pids
num_total_imgs=num_train_imgs+num_query_imgs
print("=> Market1501 loaded")
print("------------------------------------------------------------------------")
print(" subset: train \t| num_id: {:5d} \t| num_imgs:{:8d} ".format(num_train_pids,num_train_imgs))
print(" subset: query \t| num_id: {:5d} \t| num_imgs:{:8d} ".format(num_query_pids,num_query_imgs))
print(" subset: gallery \t| num_id: {:5d} \t| num_imgs:{:8d} ".format(num_gallery_pids,num_gallery_imgs))
print("------------------------------------------------------------------------")
print(" total \t\t\t| num_id: {:5d} \t| num_imgs:{:8d} ".format(num_total_pids,num_total_imgs))
print("------------------------------------------------------------------------")
self.train=train
self.query=query
self.gallery=gallery
self.num_train_pids=num_train_pids
self.num_query_pids=num_query_pids
self.num_gallery_pids=num_gallery_pids
def _process_dir(self,dir_path,relabel=False):
img_paths=glob.glob(osp.join(dir_path,'*.jpg'))
pid_container=set()
for img_path in img_paths:
pid=int(img_path.split("\\")[-1].split("_")[0])
if pid==-1:continue
pid_container.add(pid)
pid2label={pid:label for label,pid in enumerate(pid_container)}
dataset=[]
for img_path in img_paths:
str_list=img_path.split("\\")[-1].split("_")
pid=int(str_list[0])
cid=int(str_list[1][1:2])
if pid==-1:continue
assert 0<=pid <=1501
assert 1<=cid<=6
cid+=-1
if relabel:
pid=pid2label[pid]
dataset.append((img_path,pid,cid))
num_pids=len(pid_container)
num_imgs=len(img_paths)
#返回一個數據為三元組(圖片地址,行人ID,攝像機ID)的索引列表形式的資料集,行人ID數量,圖片數量
return dataset, num_pids, num_imgs
if __name__=='__main__':
data=Market1501()