1. 程式人生 > >TensorFlow與中文手寫漢字識別

TensorFlow與中文手寫漢字識別

Goal

本文目標是利用TensorFlow做一個簡單的影象分類器,在比較大的資料集上,儘可能高效地做影象相關處理,從Train,Validation到Inference,是一個比較基本的Example, 從一個基本的任務學習如果在TensorFlow下做高效地影象讀取,基本的影象處理,整個專案很簡單,但其中有一些trick,在實際專案當中有很大的好處, 比如絕對不要一次讀入所有的 的資料到記憶體(儘管在Mnist這類級別的例子上經常出現)…

最開始看到是這篇blog裡面的TensorFlow練習22: 手寫漢字識別, 但是這篇文章只用了140訓練與測試,試了下程式碼 很快,但是當擴充套件到所有的時,發現32g的記憶體都不夠用,這才注意到原文中都是用numpy,會先把所有的資料放入到記憶體,但這個不必須的,無論在MXNet還是TensorFlow中都是不必 須的,MXNet使用的是DataIter,會在程式執行的過程中非同步讀取資料,TensorFlow也是這樣的,TensorFlow封裝了高階的api,用來做資料的讀取,比如TFRecord,還有就是從filenames中讀取, 來非同步讀取檔案,然後做shuffle batch,再feed到模型的Graph中來做模型引數的更新。具體在tf如何做資料的讀取可以看看

reading data in tensorflow

這裡我會拿到所有的資料集來做訓練與測試,算作是對斗大的熊貓上面那篇文章的一個擴充套件。

Batch Generate

資料集來自於中科院自動化研究所,感謝分享精神!!!具體下載:

wget http://www.nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.1trn_gnt.zip
wget http://www.nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.1tst_gnt.zip

解壓後發現是一些gnt檔案,然後用了斗大的熊貓

裡面的程式碼,將所有檔案都轉化為對應label目錄下的所有png的圖片。(注意在HWDB1.1trn_gnt.zip解壓後是alz檔案,需要再次解壓 我在mac沒有找到合適的工具,windows上有alz的解壓工具)。

import os
import numpy as np
importstructfrom PIL importImage


data_dir ='../data'
train_data_dir = os.path.join(data_dir,'HWDB1.1trn_gnt')
test_data_dir = os.path.join(data_dir,'HWDB1.1tst_gnt'
)def read_from_gnt_dir(gnt_dir=train_data_dir):def one_file(f): header_size =10whileTrue: header = np.fromfile(f, dtype='uint8', count=header_size)ifnot header.size:break sample_size = header[0]+(header[1]<<8)+(header[2]<<16)+(header[3]<<24) tagcode = header[5]+(header[4]<<8) width = header[6]+(header[7]<<8) height = header[8]+(header[9]<<8)if header_size + width*height != sample_size:break image = np.fromfile(f, dtype='uint8', count=width*height).reshape((height, width))yield image, tagcode for file_name in os.listdir(gnt_dir):if file_name.endswith('.gnt'): file_path = os.path.join(gnt_dir, file_name)with open(file_path,'rb')as f:for image, tagcode in one_file(f):yield image, tagcode char_set =set()for _, tagcode in read_from_gnt_dir(gnt_dir=train_data_dir): tagcode_unicode =struct.pack('>H', tagcode).decode('gb2312') char_set.add(tagcode_unicode) char_list = list(char_set) char_dict = dict(zip(sorted(char_list), range(len(char_list))))print len(char_dict)import pickle f = open('char_dict','wb') pickle.dump(char_dict, f) f.close() train_counter =0 test_counter =0for image, tagcode in read_from_gnt_dir(gnt_dir=train_data_dir): tagcode_unicode =struct.pack('>H', tagcode).decode('gb2312') im =Image.fromarray(image) dir_name ='../data/train/'+'%0.5d'%char_dict[tagcode_unicode]ifnot os.path.exists(dir_name): os.mkdir(dir_name) im.convert('RGB').save(dir_name+'/'+ str(train_counter)+'.png') train_counter +=1for image, tagcode in read_from_gnt_dir(gnt_dir=test_data_dir): tagcode_unicode =struct.pack('>H', tagcode).decode('gb2312') im =Image.fromarray(image) dir_name ='../data/test/'+'%0.5d'%char_dict[tagcode_unicode]ifnot os.path.exists(dir_name): os.mkdir(dir_name) im.convert('RGB').save(dir_name+'/'+ str(test_counter)+'.png') test_counter +=1

處理好的資料,放到了雲盤,大家可以直接在我的雲盤來下載處理好的資料集HWDB1. 這裡說明下,char_dict是漢字和對應的數字label的記錄。

得到資料集後,就要考慮如何讀取了,一次用numpy讀入記憶體在很多小資料集上是可以行的,但是在稍微大點的資料集上記憶體就成了瓶頸,但是不要害怕,TensorFlow有自己的方法:

def batch_data(file_labels,sess, batch_size=128):
    image_list =[file_label[0]for file_label in file_labels]
    label_list =[int(file_label[1])for file_label in file_labels]print'tag2 {0}'.format(len(image_list))
    images_tensor = tf.convert_to_tensor(image_list, dtype=tf.string)
    labels_tensor = tf.convert_to_tensor(label_list, dtype=tf.int64)
    input_queue = tf.train.slice_input_producer([images_tensor, labels_tensor])

    labels = input_queue[1]
    images_content = tf.read_file(input_queue[0])# images = tf.image.decode_png(images_content, channels=1)
    images = tf.image.convert_image_dtype(tf.image.decode_png(images_content, channels=1), tf.float32