讀取tfrecord,並寫入h5檔案
阿新 • • 發佈:2022-03-09
import glob
import os
import sys

import h5py
import numpy as np
import pandas as pd
import tfrecord as tfr
from tqdm import tqdm


class TfrecordWorker():
    """Copy a list of TFRecord files into one HDF5 file, one row per file.

    The feature schema (label, element type, shape) is taken from the first
    record of the first file in ``tfr_list``; every other file is expected to
    carry the same features.  Each feature becomes an HDF5 dataset of shape
    ``(len(tfr_list), feature_length)`` and an extra ``"name"`` dataset stores
    the base file name of the TFRecord that filled each row.

    Attributes:
        info: dict with parallel lists ``"label"``, ``"typee"`` and
            ``"shape"`` describing each feature found in the first record.
        data_dir: constant ``"raw_data"``; not read by this class.
        tfr_list: the TFRecord paths, in the order rows are written.
        tfr_description: ``{label: type}`` mapping parsed from the CSV and
            passed to ``tfr.tfrecord_loader``.
        attr_size: number of features discovered.
        data_size: number of TFRecord files (rows in every dataset).
    """

    def __init__(self, tfr_list, description_csv="label_type.csv"):
        """Parse the feature description and probe the first file for its schema.

        Args:
            tfr_list: sequence of TFRecord file paths; must not be empty.
            description_csv: CSV file with ``label`` and ``type`` columns
                describing the features to load.

        Raises:
            IndexError: if ``tfr_list`` is empty.
        """
        self.info = {"label": [], "typee": [], "shape": []}
        self.data_dir = "raw_data"
        self.tfr_list = tfr_list
        self.tfr_description = self._parse_description(description_csv)
        if len(self.tfr_list) == 0:
            # Same exception type that indexing an empty list would raise,
            # but with a message that says what actually went wrong.
            raise IndexError("tfr_list is empty: no tfrecord file to read the schema from")
        # NOTE(review): the second positional argument is presumably the
        # loader's index path (None = no index) - confirm against the
        # tfrecord package documentation.
        loader = tfr.tfrecord_loader(self.tfr_list[0], None, self.tfr_description)
        for record in loader:
            for key in record.keys():
                self.info['label'].append(key)
                self.info['typee'].append(type(record[key][0]))
                self.info['shape'].append(record[key].shape)
            # Only the first record defines the schema.  Scanning further
            # records would append every label again and make create_h5f()
            # try to create the same dataset twice.
            break
        self.attr_size = len(self.info['label'])
        self.data_size = len(self.tfr_list)
        print(f"總共有{self.attr_size}個屬性")
        print(f"總共有{self.data_size}個tfrecord檔案")

    def create_h5f(self, h5path="./data.h5"):
        """Create (or truncate) the HDF5 file and allocate one dataset per feature.

        Args:
            h5path: destination path; opened in ``'w'`` mode.
        """
        self.h5f = h5py.File(h5path, 'w')
        self.dset = {}
        schema = zip(self.info["label"], self.info["typee"], self.info["shape"])
        for label, typee, shape in schema:
            # Only shape[0] is used, i.e. each feature is stored as a flat
            # vector of that length per file.
            self.dset[label] = self.h5f.create_dataset(
                label,
                shape=[self.data_size, shape[0]],
                compression=None,
                dtype=typee,
            )
        # Variable-length string dataset holding the source file name per row.
        self.dset["name"] = self.h5f.create_dataset(
            "name",
            shape=[self.data_size],
            compression=None,
            dtype=h5py.special_dtype(vlen=str),
        )

    def write_h5f(self):
        """Write every TFRecord file into its row, showing a progress bar."""
        # tqdm wraps the list itself (not the enumerate object) so that it
        # can read the length and display a total / ETA.
        for idx, tfr_path in enumerate(tqdm(self.tfr_list)):
            self._write_one_item(tfr_path, idx)

    def close_h5f(self):
        """Close the HDF5 file opened by create_h5f()."""
        self.h5f.close()

    def _write_one_item(self, tfr_path, idx):
        """Store the features of one TFRecord file in row ``idx``.

        Every record in the file is written to the same row, so if a file
        holds several records the last one is what remains in the output.

        Args:
            tfr_path: path of the TFRecord file to read.
            idx: row index in every dataset.
        """
        loader = tfr.tfrecord_loader(tfr_path, None, self.tfr_description)
        for record in loader:
            for key in record.keys():
                self.dset[key][idx] = record[key]
        # basename() copes with the platform's path separator, unlike a
        # manual split on "/".
        self.dset["name"][idx] = os.path.basename(tfr_path)

    def _parse_description(self, csv_path):
        """Read the ``label`` and ``type`` columns of a CSV into a dict.

        Args:
            csv_path: path of the CSV file.

        Returns:
            dict mapping each stripped label string to its stripped type
            string, in file order (later duplicates overwrite earlier ones).
        """
        label_type = pd.read_csv(csv_path, usecols=["label", "type"])
        return {
            str(row['label']).strip(): str(row['type']).strip()
            for _, row in label_type.iterrows()
        }


def start(files, savename):
    """Convert the TFRecord files in ``files`` into the HDF5 file ``savename``.

    Args:
        files: sequence of TFRecord paths; one output row per path.
        savename: path of the HDF5 file to create.
    """
    worker = TfrecordWorker(files)
    worker.create_h5f(savename)
    try:
        worker.write_h5f()
    finally:
        # Always release the HDF5 handle, even when a file fails to convert.
        worker.close_h5f()
# Convert the TFRecord files of each fold into that fold's HDF5 file.
# glob.glob() returns matches in arbitrary order, so the paths are sorted to
# make the row order of every output file (and the file used to probe the
# feature schema) reproducible across runs and machines.
start(sorted(glob.glob("raw_data/*fold0*.tfrecord")), "fold0.h5")
start(sorted(glob.glob("raw_data/*fold1*.tfrecord")), "fold1.h5")
start(sorted(glob.glob("raw_data/*fold2*.tfrecord")), "fold2.h5")
start(sorted(glob.glob("raw_data/*fold3*.tfrecord")), "fold3.h5")
# Sanity check: reopen the first fold read-only, list the datasets it holds
# and print the stored source file names.  The context manager closes the
# file handle even if one of the reads fails; `name` must not be used after
# the block because its backing file is closed by then.
with h5py.File('fold0.h5', 'r') as f:
    print('--iterms: ', len(f.keys()), f.keys())
    name = f['name']
    print(name[:])