中文文字識別 FSNS格式tfrecord生成
最近,想使用谷歌的Attention OCR做中文文字識別,專案github地址:https://github.com/A-bone1/Attention-ocr-Chinese-Version,中文介紹可參考CSDN部落格:https://blog.csdn.net/qq_40003316/article/details/80062023。
研究後發現該模型的訓練資料需要提供FSNS格式的訓練資料,而官方也沒有給出相關的文件,只給了一個stackoverflow的連結https://stackoverflow.com/a/44461910/743658,可是說的也不清楚。所以自己參考網上的一些辦法,寫了一個生成FSNS格式tfrecord的小程式碼。github地址為:https://github.com/A-bone1/FSNS-tfrecord-generate。
FSNS的具體格式在這篇論文有說:https://arxiv.org/pdf/1702.03970.pdf
但是,我們只需關心表四即可:
image/format表示圖片的格式,是‘png’ ,如果你生的tfrecord是使用jpg格式,可改成‘raw’
image/encoded 表示圖片的具體內容,佔用一個string,以‘png’的格式編碼
iamge/class表示圖片真實的類別id,是37個int64資料,每一個int64對應一個字元編碼,具體的對映方式在charset_size=134.txt檔案中,要生成自己的資料需要自己建立類似的字典,如我自己建立的包含5400箇中文的dic.txt。
image/unpadded_class 表示圖片在沒有被填充之前真實的id。
image/width:表示圖片的畫素的寬度
image/orig_width:表示圖片在沒有填充之前畫素的寬度
image/height:表示圖片的畫素的高度,在tensorflow程式碼中,這一部分並沒有寫入程式碼,因為圖片高度固定為150
image/test:佔用一個string,是使用UTF-8編碼的真實的字元形式的標記
下面直接上程式碼:(上傳的程式碼是將jpg圖片直接儲存為tfrecord,速度較快,如果讀者想生成png編碼的tfrecord,可以參考我的github。
from random import shuffle import numpy as np import glob import tensorflow as tf import cv2 import sys import os import PIL.Image as Image def encode_utf8_string(text, length, dic, null_char_id=5462): char_ids_padded = [null_char_id]*length char_ids_unpadded = [null_char_id]*len(text) for i in range(len(text)): hash_id = dic[text[i]] char_ids_padded[i] = hash_id char_ids_unpadded[i] = hash_id return char_ids_padded, char_ids_unpadded def _bytes_feature(value): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) def _int64_feature(value): return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) dict={} with open('dic.txt', encoding="utf") as dict_file: for line in dict_file: (key, value) = line.strip().split('\t') dict[value] = int(key) print((dict)) image_path = 'data/*/*.jpg' addrs_image = glob.glob(image_path) label_path = 'data/*/*.txt' addrs_label = glob.glob(label_path) print(len(addrs_image)) print(len(addrs_label)) tfrecord_writer = tf.python_io.TFRecordWriter("tfexample_train") for j in range(0,int(len(addrs_image))): # 這是寫入操作視覺化處理 print('Train data: {}/{}'.format(j,int(len(addrs_image)))) sys.stdout.flush() img = Image.open(addrs_image[j]) img = img.resize((600, 150), Image.ANTIALIAS) np_data = np.array(img) image_data = img.tobytes() for text in open(addrs_label[j], encoding="utf"): char_ids_padded, char_ids_unpadded = encode_utf8_string( text=text, dic=dict, length=37, null_char_id=5462) example = tf.train.Example(features=tf.train.Features( feature={ 'image/encoded': _bytes_feature(image_data), 'image/format': _bytes_feature(b"raw"), 'image/width': _int64_feature([np_data.shape[1]]), 'image/orig_width': _int64_feature([np_data.shape[1]]), 'image/class': _int64_feature(char_ids_padded), 'image/unpadded_class': _int64_feature(char_ids_unpadded), 'image/text': _bytes_feature(bytes(text, 'utf-8')), # 'height': _int64_feature([crop_data.shape[0]]), } )) tfrecord_writer.write(example.SerializeToString()) tfrecord_writer.close() sys.stdout.flush()
原文地址: