1. 程式人生 > >中文文字識別 FSNS格式tfrecord生成

中文文字識別 FSNS格式tfrecord生成

最近,想使用谷歌的Attention OCR做中文文字識別,專案github地址:https://github.com/A-bone1/Attention-ocr-Chinese-Version,中文介紹可參考CSDN部落格:https://blog.csdn.net/qq_40003316/article/details/80062023。
        研究後發現該模型的訓練資料需要提供FSNS格式的訓練資料,而官方也沒有給出相關的文件,只給了一個stackoverflow的連結https://stackoverflow.com/a/44461910/743658,可是說的也不清楚。所以自己參考網上的一些辦法,寫了一個生成FSNS格式tfrecord的小程式碼。github地址為:https://github.com/A-bone1/FSNS-tfrecord-generate。

      FSNS的具體格式在這篇論文有說:https://arxiv.org/pdf/1702.03970.pdf

      但是,我們只需關心表四即可:


image/format表示圖片的格式,是‘png’ ,如果你生的tfrecord是使用jpg格式,可改成‘raw’
image/encoded 表示圖片的具體內容,佔用一個string,以‘png’的格式編碼 
iamge/class表示圖片真實的類別id,是37個int64資料,每一個int64對應一個字元編碼,具體的對映方式在charset_size=134.txt檔案中,要生成自己的資料需要自己建立類似的字典,如我自己建立的包含5400箇中文的dic.txt。
image/unpadded_class 表示圖片在沒有被填充之前真實的id。 
image/width:表示圖片的畫素的寬度 
image/orig_width:表示圖片在沒有填充之前畫素的寬度 
image/height:表示圖片的畫素的高度,在tensorflow程式碼中,這一部分並沒有寫入程式碼,因為圖片高度固定為150 
image/test:佔用一個string,是使用UTF-8編碼的真實的字元形式的標記
 
        下面直接上程式碼:(上傳的程式碼是將jpg圖片直接儲存為tfrecord,速度較快,如果讀者想生成png編碼的tfrecord,可以參考我的github。
 

from random import shuffle
import numpy as np
import glob
import tensorflow as tf
import cv2
import sys
import os
import PIL.Image as Image
 
 
def encode_utf8_string(text, length, dic, null_char_id=5462):
    char_ids_padded = [null_char_id]*length
    char_ids_unpadded = [null_char_id]*len(text)
    for i in range(len(text)):
        hash_id = dic[text[i]]
        char_ids_padded[i] = hash_id
        char_ids_unpadded[i] = hash_id
    return char_ids_padded, char_ids_unpadded
 
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
 
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
 
dict={}
with open('dic.txt', encoding="utf") as dict_file:
    for line in dict_file:
        (key, value) = line.strip().split('\t')
        dict[value] = int(key)
print((dict))
 
image_path = 'data/*/*.jpg'
addrs_image = glob.glob(image_path)
 
label_path = 'data/*/*.txt'
addrs_label = glob.glob(label_path)
 
print(len(addrs_image))
print(len(addrs_label))
 
tfrecord_writer  = tf.python_io.TFRecordWriter("tfexample_train") 
for j in range(0,int(len(addrs_image))):
    
 
            # 這是寫入操作視覺化處理
    print('Train data: {}/{}'.format(j,int(len(addrs_image))))
    sys.stdout.flush()
 
    img = Image.open(addrs_image[j])
 
    img = img.resize((600, 150), Image.ANTIALIAS)
    np_data = np.array(img)
    image_data = img.tobytes()
    for text in open(addrs_label[j], encoding="utf"):
                 char_ids_padded, char_ids_unpadded = encode_utf8_string(
                            text=text,
                            dic=dict,
                            length=37,
                            null_char_id=5462)
 
 
 
    example = tf.train.Example(features=tf.train.Features(
                        feature={
                            'image/encoded': _bytes_feature(image_data),
                            'image/format': _bytes_feature(b"raw"),
                            'image/width': _int64_feature([np_data.shape[1]]),
                            'image/orig_width': _int64_feature([np_data.shape[1]]),
                            'image/class': _int64_feature(char_ids_padded),
                            'image/unpadded_class': _int64_feature(char_ids_unpadded),
                            'image/text': _bytes_feature(bytes(text, 'utf-8')),
                            # 'height': _int64_feature([crop_data.shape[0]]),
                        }
                    ))
    tfrecord_writer.write(example.SerializeToString())
tfrecord_writer.close()
 
sys.stdout.flush()

原文地址: