1. 程式人生 > 程式設計 >pytorch製作自己的LMDB資料操作示例

pytorch製作自己的LMDB資料操作示例

本文例項講述了pytorch製作自己的LMDB資料操作。分享給大家供大家參考,具體如下:

前言

記錄下pytorch裡如何使用lmdb的code,自用

製作部分的Code

code就是ASTER裡資料製作部分的程式碼改了點,aster_train.txt裡面就算圖片的完整路徑每行一個,圖片同目錄下有同名的txt,裡面記著jpg的標籤

import os
import lmdb # install lmdb by "pip install lmdb"
import cv2
import numpy as np
from tqdm import tqdm
import six
from PIL import Image
import scipy.io as sio
from tqdm import tqdm
import re
def checkImageIsValid(imageBin):
 if imageBin is None:
  return False
 imageBuf = np.fromstring(imageBin,dtype=np.uint8)
 img = cv2.imdecode(imageBuf,cv2.IMREAD_GRAYSCALE)
 imgH,imgW = img.shape[0],img.shape[1]
 if imgH * imgW == 0:
  return False
 return True
def writeCache(env,cache):
 with env.begin(write=True) as txn:
  for k,v in cache.items():
   txn.put(k.encode(),v)
def _is_difficult(word):
 assert isinstance(word,str)
 return not re.match('^[\w]+$',word)
def createDataset(outputPath,imagePathList,labelList,lexiconList=None,checkValid=True):
 """
 Create LMDB dataset for CRNN training.
 ARGS:
   outputPath  : LMDB output path
   imagePathList : list of image path
   labelList   : list of corresponding groundtruth texts
   lexiconList  : (optional) list of lexicon lists
   checkValid  : if true,check the validity of every image
 """
 assert(len(imagePathList) == len(labelList))
 nSamples = len(imagePathList)
 env = lmdb.open(outputPath,map_size=1099511627776)#最大空間1048576GB
 cache = {}
 cnt = 1
 for i in range(nSamples):
  imagePath = imagePathList[i]
  label = labelList[i]
  if len(label) == 0:
   continue
  if not os.path.exists(imagePath):
   print('%s does not exist' % imagePath)
   continue
  with open(imagePath,'rb') as f:
   imageBin = f.read()
  if checkValid:
   if not checkImageIsValid(imageBin):
    print('%s is not a valid image' % imagePath)
    continue
  #資料庫中都是二進位制資料
  imageKey = 'image-%09d' % cnt#9位數不足填零
  labelKey = 'label-%09d' % cnt
  cache[imageKey] = imageBin
  cache[labelKey] = label.encode()
  if lexiconList:
   lexiconKey = 'lexicon-%09d' % cnt
   cache[lexiconKey] = ' '.join(lexiconList[i])
  if cnt % 1000 == 0:
   writeCache(env,cache)
   cache = {}
   print('Written %d / %d' % (cnt,nSamples))
  cnt += 1
 nSamples = cnt-1
 cache['num-samples'] = str(nSamples).encode()
 writeCache(env,cache)
 print('Created dataset with %d samples' % nSamples)
def get_sample_list(txt_path:str):
  with open(txt_path,'r') as fr:
    jpg_list=[x.strip() for x in fr.readlines() if os.path.exists(x.replace('.jpg','.txt').strip())]
  txt_content_list=[]
  for jpg in jpg_list:
    label_path=jpg.replace('.jpg','.txt')
    with open(label_path,'r') as fr:
      try:
        str_tmp=fr.readline()
      except UnicodeDecodeError as e:
        print(label_path)
        raise(e)
      txt_content_list.append(str_tmp.strip())
  return jpg_list,txt_content_list
if __name__ == "__main__":
 txt_path='/home/gpu-server/disk/disk1/NumberData/8NumberSample/aster_train.txt'
 lmdb_output_path = '/home/gpu-server/project/aster/dataset/train'
 imagePathList,labelList=get_sample_list(txt_path)
 createDataset(lmdb_output_path,labelList)

讀取部分

這裡用的pytorch的dataloader,簡單記錄一下,人比較懶,程式碼就直接抄過來,不整理拆分了,重點看__getitem__

from __future__ import absolute_import
# import sys
# sys.path.append('./')
import os
# import moxing as mox
import pickle
from tqdm import tqdm
from PIL import Image,ImageFile
import numpy as np
import random
import cv2
import lmdb
import sys
import six
import torch
from torch.utils import data
from torch.utils.data import sampler
from torchvision import transforms
from lib.utils.labelmaps import get_vocabulary,labels2strs
from lib.utils import to_numpy
ImageFile.LOAD_TRUNCATED_IMAGES = True
from config import get_args
global_args = get_args(sys.argv[1:])
if global_args.run_on_remote:
 import moxing as mox
 #moxing是一個分散式的框架 跳過
class LmdbDataset(data.Dataset):
 def __init__(self,root,voc_type,max_len,num_samples,transform=None):
  super(LmdbDataset,self).__init__()
  if global_args.run_on_remote:
   dataset_name = os.path.basename(root)
   data_cache_url = "/cache/%s" % dataset_name
   if not os.path.exists(data_cache_url):
    os.makedirs(data_cache_url)
   if mox.file.exists(root):
    mox.file.copy_parallel(root,data_cache_url)
   else:
    raise ValueError("%s not exists!" % root)
   self.env = lmdb.open(data_cache_url,max_readers=32,readonly=True)
  else:
   self.env = lmdb.open(root,readonly=True)
  assert self.env is not None,"cannot create lmdb from %s" % root
  self.txn = self.env.begin()
  self.voc_type = voc_type
  self.transform = transform
  self.max_len = max_len
  self.nSamples = int(self.txn.get(b"num-samples"))
  self.nSamples = min(self.nSamples,num_samples)
  assert voc_type in ['LOWERCASE','ALLCASES','ALLCASES_SYMBOLS','DIGITS']
  self.EOS = 'EOS'
  self.PADDING = 'PADDING'
  self.UNKNOWN = 'UNKNOWN'
  self.voc = get_vocabulary(voc_type,EOS=self.EOS,PADDING=self.PADDING,UNKNOWN=self.UNKNOWN)
  self.char2id = dict(zip(self.voc,range(len(self.voc))))
  self.id2char = dict(zip(range(len(self.voc)),self.voc))
  self.rec_num_classes = len(self.voc)
  self.lowercase = (voc_type == 'LOWERCASE')
 def __len__(self):
  return self.nSamples
 def __getitem__(self,index):
  assert index <= len(self),'index range error'
  index += 1
  img_key = b'image-%09d' % index
  imgbuf = self.txn.get(img_key)
  #由於Image.open需要一個類檔案物件 所以這裡需要把二進位制轉為一個類檔案物件
  buf = six.BytesIO()
  buf.write(imgbuf)
  buf.seek(0)
  try:
   img = Image.open(buf).convert('RGB')
   # img = Image.open(buf).convert('L')
   # img = img.convert('RGB')
  except IOError:
   print('Corrupted image for %d' % index)
   return self[index + 1]
  # reconition labels
  label_key = b'label-%09d' % index
  word = self.txn.get(label_key).decode()
  if self.lowercase:
   word = word.lower()
  ## fill with the padding token
  label = np.full((self.max_len,),self.char2id[self.PADDING],dtype=np.int)
  label_list = []
  for char in word:
   if char in self.char2id:
    label_list.append(self.char2id[char])
   else:
    ## add the unknown token
    print('{0} is out of vocabulary.'.format(char))
    label_list.append(self.char2id[self.UNKNOWN])
  ## add a stop token
  label_list = label_list + [self.char2id[self.EOS]]
  assert len(label_list) <= self.max_len
  label[:len(label_list)] = np.array(label_list)
  if len(label) <= 0:
   return self[index + 1]
  # label length
  label_len = len(label_list)
  if self.transform is not None:
   img = self.transform(img)
  return img,label,label_len

更多關於Python相關內容可檢視本站專題:《Python數學運算技巧總結》、《Python圖片操作技巧總結》、《Python資料結構與演算法教程》、《Python函式使用技巧總結》、《Python字串操作技巧彙總》及《Python入門與進階經典教程》

希望本文所述對大家Python程式設計有所幫助。