pytorch製作自己的LMDB資料操作示例
阿新 • • 發佈:2020-01-09
本文例項講述了pytorch製作自己的LMDB資料操作。分享給大家供大家參考,具體如下:
前言
記錄下pytorch裡如何使用lmdb的code,自用
製作部分的Code
code 是把 ASTER 裡資料製作部分的程式碼稍微改了一下。aster_train.txt 裡面就是圖片的完整路徑,每行一個;圖片同目錄下有同名的 txt,裡面記著該 jpg 的標籤。
import os
import re

import cv2
import lmdb  # install lmdb by "pip install lmdb"
import numpy as np
import scipy.io as sio
import six
from PIL import Image
from tqdm import tqdm


def checkImageIsValid(imageBin):
    """Return True if ``imageBin`` decodes to an image with a non-zero area.

    ARGS:
        imageBin : raw encoded image bytes (e.g. the content of a .jpg file)

    Returns False for ``None``, empty input, data OpenCV cannot decode, or a
    decoded image whose height * width is 0.
    """
    if not imageBin:
        return False
    # frombuffer replaces the deprecated np.fromstring for binary input.
    imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
    img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
    # cv2.imdecode signals failure by returning None instead of raising.
    if img is None:
        return False
    imgH, imgW = img.shape[0], img.shape[1]
    return imgH * imgW != 0


def writeCache(env, cache):
    """Write every key/value pair of ``cache`` to ``env`` in one transaction.

    ARGS:
        env   : an opened, writable LMDB environment
        cache : dict mapping str keys to values; str values are encoded to
                bytes so that everything stored in the database is binary
    """
    with env.begin(write=True) as txn:
        for k, v in cache.items():
            if isinstance(v, str):
                v = v.encode()
            txn.put(k.encode(), v)


def _is_difficult(word):
    """Return True if ``word`` contains anything other than ``\\w`` characters."""
    assert isinstance(word, str)
    return not re.match(r'^[\w]+$', word)


def createDataset(outputPath, imagePathList, labelList,
                  lexiconList=None, checkValid=True):
    """
    Create LMDB dataset for CRNN training.

    Samples are stored under the keys ``image-%09d`` / ``label-%09d``
    (and ``lexicon-%09d`` when lexicons are given), numbered from 1, and the
    total number of stored samples is written under ``num-samples``.
    Samples with an empty label, a missing image file, or (when checkValid is
    true) an undecodable image are skipped and do not consume a number.

    ARGS:
        outputPath    : LMDB output path
        imagePathList : list of image path
        labelList     : list of corresponding groundtruth texts
        lexiconList   : (optional) list of lexicon lists
        checkValid    : if true, check the validity of every image
    """
    assert len(imagePathList) == len(labelList)
    nSamples = len(imagePathList)
    # map_size is the maximum size of the database: 1099511627776 B = 1 TiB.
    env = lmdb.open(outputPath, map_size=1099511627776)
    try:
        cache = {}
        cnt = 1
        for i, (imagePath, label) in enumerate(zip(imagePathList, labelList)):
            if len(label) == 0:
                continue
            if not os.path.exists(imagePath):
                print('%s does not exist' % imagePath)
                continue
            with open(imagePath, 'rb') as f:
                imageBin = f.read()
            if checkValid and not checkImageIsValid(imageBin):
                print('%s is not a valid image' % imagePath)
                continue

            # Everything stored in the database is binary data.
            imageKey = 'image-%09d' % cnt  # 9 digits, zero-padded
            labelKey = 'label-%09d' % cnt
            cache[imageKey] = imageBin
            cache[labelKey] = label.encode()
            if lexiconList:
                lexiconKey = 'lexicon-%09d' % cnt
                cache[lexiconKey] = ' '.join(lexiconList[i]).encode()
            # Flush every 1000 samples to bound the size of the in-memory cache.
            if cnt % 1000 == 0:
                writeCache(env, cache)
                cache = {}
                print('Written %d / %d' % (cnt, nSamples))
            cnt += 1

        nSamples = cnt - 1
        cache['num-samples'] = str(nSamples).encode()
        writeCache(env, cache)
    finally:
        # Release the environment even if reading or writing a sample failed.
        env.close()
    print('Created dataset with %d samples' % nSamples)


def get_sample_list(txt_path: str):
    """Collect image paths and their labels from an index file.

    ARGS:
        txt_path : text file with one image path per line

    For every listed image, the label is the first line of the ``.txt`` file
    that has the same name and sits in the same directory as the image.
    Lines whose label file does not exist (and blank lines) are skipped.

    RETURNS:
        (jpg_list, txt_content_list) : two lists of equal length holding the
        image paths and the stripped label strings.
    """
    jpg_list = []
    txt_content_list = []
    with open(txt_path, 'r') as fr:
        image_paths = [line.strip() for line in fr]

    for jpg in image_paths:
        if not jpg:
            continue
        # Swap only the extension, so a '.jpg' elsewhere in the path is left
        # alone and images with other extensions still map to a .txt file.
        label_path = os.path.splitext(jpg)[0] + '.txt'
        if not os.path.exists(label_path):
            continue
        with open(label_path, 'r') as label_file:
            try:
                str_tmp = label_file.readline()
            except UnicodeDecodeError:
                # Report which label file is badly encoded before failing.
                print(label_path)
                raise
        jpg_list.append(jpg)
        txt_content_list.append(str_tmp.strip())
    return jpg_list, txt_content_list


if __name__ == "__main__":
    txt_path = '/home/gpu-server/disk/disk1/NumberData/8NumberSample/aster_train.txt'
    lmdb_output_path = '/home/gpu-server/project/aster/dataset/train'
    imagePathList, labelList = get_sample_list(txt_path)
    createDataset(lmdb_output_path, imagePathList, labelList)
讀取部分
這裡用的是 PyTorch 的 Dataset(搭配 DataLoader 使用),簡單記錄一下。人比較懶,程式碼就直接抄過來,不整理拆分了,重點看 `__getitem__`。
from __future__ import absolute_import # import sys # sys.path.append('./') import os # import moxing as mox import pickle from tqdm import tqdm from PIL import Image,ImageFile import numpy as np import random import cv2 import lmdb import sys import six import torch from torch.utils import data from torch.utils.data import sampler from torchvision import transforms from lib.utils.labelmaps import get_vocabulary,labels2strs from lib.utils import to_numpy ImageFile.LOAD_TRUNCATED_IMAGES = True from config import get_args global_args = get_args(sys.argv[1:]) if global_args.run_on_remote: import moxing as mox #moxing是一個分散式的框架 跳過 class LmdbDataset(data.Dataset): def __init__(self,root,voc_type,max_len,num_samples,transform=None): super(LmdbDataset,self).__init__() if global_args.run_on_remote: dataset_name = os.path.basename(root) data_cache_url = "/cache/%s" % dataset_name if not os.path.exists(data_cache_url): os.makedirs(data_cache_url) if mox.file.exists(root): mox.file.copy_parallel(root,data_cache_url) else: raise ValueError("%s not exists!" 
% root) self.env = lmdb.open(data_cache_url,max_readers=32,readonly=True) else: self.env = lmdb.open(root,readonly=True) assert self.env is not None,"cannot create lmdb from %s" % root self.txn = self.env.begin() self.voc_type = voc_type self.transform = transform self.max_len = max_len self.nSamples = int(self.txn.get(b"num-samples")) self.nSamples = min(self.nSamples,num_samples) assert voc_type in ['LOWERCASE','ALLCASES','ALLCASES_SYMBOLS','DIGITS'] self.EOS = 'EOS' self.PADDING = 'PADDING' self.UNKNOWN = 'UNKNOWN' self.voc = get_vocabulary(voc_type,EOS=self.EOS,PADDING=self.PADDING,UNKNOWN=self.UNKNOWN) self.char2id = dict(zip(self.voc,range(len(self.voc)))) self.id2char = dict(zip(range(len(self.voc)),self.voc)) self.rec_num_classes = len(self.voc) self.lowercase = (voc_type == 'LOWERCASE') def __len__(self): return self.nSamples def __getitem__(self,index): assert index <= len(self),'index range error' index += 1 img_key = b'image-%09d' % index imgbuf = self.txn.get(img_key) #由於Image.open需要一個類檔案物件 所以這裡需要把二進位制轉為一個類檔案物件 buf = six.BytesIO() buf.write(imgbuf) buf.seek(0) try: img = Image.open(buf).convert('RGB') # img = Image.open(buf).convert('L') # img = img.convert('RGB') except IOError: print('Corrupted image for %d' % index) return self[index + 1] # reconition labels label_key = b'label-%09d' % index word = self.txn.get(label_key).decode() if self.lowercase: word = word.lower() ## fill with the padding token label = np.full((self.max_len,),self.char2id[self.PADDING],dtype=np.int) label_list = [] for char in word: if char in self.char2id: label_list.append(self.char2id[char]) else: ## add the unknown token print('{0} is out of vocabulary.'.format(char)) label_list.append(self.char2id[self.UNKNOWN]) ## add a stop token label_list = label_list + [self.char2id[self.EOS]] assert len(label_list) <= self.max_len label[:len(label_list)] = np.array(label_list) if len(label) <= 0: return self[index + 1] # label length label_len = len(label_list) if self.transform is 
not None: img = self.transform(img) return img,label,label_len
更多關於Python相關內容可檢視本站專題:《Python數學運算技巧總結》、《Python圖片操作技巧總結》、《Python資料結構與演算法教程》、《Python函式使用技巧總結》、《Python字串操作技巧彙總》及《Python入門與進階經典教程》
希望本文所述對大家Python程式設計有所幫助。