1. 程式人生 > 其它 >行人重識別(8)——程式碼實踐之資料管理器(dataset_manager.py)

行人重識別(8)——程式碼實踐之資料管理器(dataset_manager.py)

技術標籤:行人重識別演算法計算機視覺行人重識別

!轉載請註明原文地址!——東方旅行者

本文目錄

資料管理器(dataset_manager.py)

一、資料管理器作用

該檔案主要負責指定資料集路徑處理原始資料集並生成資料索引列表返回子資料集相關引數(子集行人ID數量,子集圖片數量)。因為Market1501已經劃分好訓練集、測試集與查詢集,所以直接可以根據路徑提取這三個資料集。

二、資料管理器編寫思路

  1. 指定資料集根目錄路徑
  2. 分別指定訓練集、測試集、查詢集的路徑
  3. 通過這三個子資料集路徑獲得子集下所有圖片的地址,通過每個圖片的地址就可以得到行人ID與攝像機ID等資訊,根據這些資訊生成一個索引列表,型別為list,列表中每個元素都是一個三元組(資料圖片地址,行人ID,攝像頭ID),與此同時獲取子資料集相關資訊,如子集行人ID數量,子集圖片數量等引數
  4. 三個資料集索引列表生成完畢,列印相關引數到控制檯

索引列表如下所示:

[
('./data/Market-1501-v15.09.15\\bounding_box_train\\0002_c1s1_000451_03.jpg',0,0),
('./data/Market-1501-v15.09.15\\bounding_box_train\\0002_c1s1_000551_01.jpg',0,0),
('./data/Market-1501-v15.09.15\\bounding_box_train\\0002_c1s1_000776_01.jpg',0,0)
]

因為每一個子集中行人ID不一定連續,所以為了便於訓練,一般要對訓練集的行人ID進行重排,便於訓練。所以需要使用一個名稱為pid2label的Map來記錄原始ID與重排ID的對應關係。

三、程式碼

import os
import os.path as osp
import numpy as np
import glob
import re
from IPython import embed
"""
Market1501類用於
1.指定資料集路徑
2.處理原始資料集並生成資料索引列表
3.返回子資料集的相關引數(子集行人ID數量,子集圖片數量)
"""
class Market1501(object): dataset_dir='data/Market-1501-v15.09.15'#指定資料集路徑 def __init__(self,root='./',**kwargs): self.dataset_dir=osp.join(root,self.dataset_dir) self.train_dir=osp.join(self.dataset_dir,'bounding_box_train')#訓練集 self.gallery_dir=osp.join(self.dataset_dir,'bounding_box_test')#測試集 self.query_dir=osp.join(self.dataset_dir,'query')#查詢集 train, num_train_pids, num_train_imgs=self._process_dir(self.train_dir,relabel=True) query, num_query_pids, num_query_imgs=self._process_dir(self.query_dir,relabel=False) gallery, num_gallery_pids, num_gallery_imgs=self._process_dir(self.gallery_dir,relabel=False) num_total_pids=num_train_pids+num_query_pids num_total_imgs=num_train_imgs+num_query_imgs print("=> Market1501 loaded") print("------------------------------------------------------------------------") print(" subset: train \t| num_id: {:5d} \t| num_imgs:{:8d} ".format(num_train_pids,num_train_imgs)) print(" subset: query \t| num_id: {:5d} \t| num_imgs:{:8d} ".format(num_query_pids,num_query_imgs)) print(" subset: gallery \t| num_id: {:5d} \t| num_imgs:{:8d} ".format(num_gallery_pids,num_gallery_imgs)) print("------------------------------------------------------------------------") print(" total \t\t\t| num_id: {:5d} \t| num_imgs:{:8d} ".format(num_total_pids,num_total_imgs)) print("------------------------------------------------------------------------") self.train=train self.query=query self.gallery=gallery self.num_train_pids=num_train_pids self.num_query_pids=num_query_pids self.num_gallery_pids=num_gallery_pids def _process_dir(self,dir_path,relabel=False): img_paths=glob.glob(osp.join(dir_path,'*.jpg')) pid_container=set() for img_path in img_paths: pid=int(img_path.split("\\")[-1].split("_")[0]) if pid==-1:continue pid_container.add(pid) pid2label={pid:label for label,pid in enumerate(pid_container)} dataset=[] for img_path in img_paths: str_list=img_path.split("\\")[-1].split("_") pid=int(str_list[0]) cid=int(str_list[1][1:2]) if pid==-1:continue assert 0<=pid <=1501 assert 1<=cid<=6 cid+=-1 if relabel: pid=pid2label[pid] dataset.append((img_path,pid,cid)) num_pids=len(pid_container) num_imgs=len(img_paths) #返回一個數據為三元組(圖片地址,行人ID,攝像機ID)的索引列表形式的資料集,行人ID數量,圖片數量 return dataset, num_pids, num_imgs if __name__=='__main__': data=Market1501()

四、測試結果

data_manager測試結果