TensorFlow與中文手寫漢字識別
Goal
本文目標是利用TensorFlow做一個簡單的影象分類器,在比較大的資料集上,儘可能高效地做影象相關處理,從Train,Validation到Inference,是一個比較基本的Example, 從一個基本的任務學習如果在TensorFlow下做高效地影象讀取,基本的影象處理,整個專案很簡單,但其中有一些trick,在實際專案當中有很大的好處, 比如絕對不要一次讀入所有的 的資料到記憶體(儘管在Mnist這類級別的例子上經常出現)…
最開始看到是這篇blog裡面的TensorFlow練習22: 手寫漢字識別, 但是這篇文章只用了140訓練與測試,試了下程式碼 很快,但是當擴充套件到所有的時,發現32g的記憶體都不夠用,這才注意到原文中都是用numpy,會先把所有的資料放入到記憶體,但這個不必須的,無論在MXNet還是TensorFlow中都是不必
須的,MXNet使用的是DataIter,會在程式執行的過程中非同步讀取資料,TensorFlow也是這樣的,TensorFlow封裝了高階的api,用來做資料的讀取,比如TFRecord,還有就是從filenames中讀取, 來非同步讀取檔案,然後做shuffle batch,再feed到模型的Graph中來做模型引數的更新。具體在tf如何做資料的讀取可以看看
這裡我會拿到所有的資料集來做訓練與測試,算作是對斗大的熊貓
上面那篇文章的一個擴充套件。
Batch Generate
資料集來自於中科院自動化研究所,感謝分享精神!!!具體下載:
wget http://www.nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.1trn_gnt.zip
wget http://www.nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.1tst_gnt.zip
解壓後發現是一些gnt檔案,然後用了斗大的熊貓
import os
import numpy as np
importstructfrom PIL importImage
data_dir ='../data'
train_data_dir = os.path.join(data_dir,'HWDB1.1trn_gnt')
test_data_dir = os.path.join(data_dir,'HWDB1.1tst_gnt' )def read_from_gnt_dir(gnt_dir=train_data_dir):def one_file(f):
header_size =10whileTrue:
header = np.fromfile(f, dtype='uint8', count=header_size)ifnot header.size:break
sample_size = header[0]+(header[1]<<8)+(header[2]<<16)+(header[3]<<24)
tagcode = header[5]+(header[4]<<8)
width = header[6]+(header[7]<<8)
height = header[8]+(header[9]<<8)if header_size + width*height != sample_size:break
image = np.fromfile(f, dtype='uint8', count=width*height).reshape((height, width))yield image, tagcode
for file_name in os.listdir(gnt_dir):if file_name.endswith('.gnt'):
file_path = os.path.join(gnt_dir, file_name)with open(file_path,'rb')as f:for image, tagcode in one_file(f):yield image, tagcode
char_set =set()for _, tagcode in read_from_gnt_dir(gnt_dir=train_data_dir):
tagcode_unicode =struct.pack('>H', tagcode).decode('gb2312')
char_set.add(tagcode_unicode)
char_list = list(char_set)
char_dict = dict(zip(sorted(char_list), range(len(char_list))))print len(char_dict)import pickle
f = open('char_dict','wb')
pickle.dump(char_dict, f)
f.close()
train_counter =0
test_counter =0for image, tagcode in read_from_gnt_dir(gnt_dir=train_data_dir):
tagcode_unicode =struct.pack('>H', tagcode).decode('gb2312')
im =Image.fromarray(image)
dir_name ='../data/train/'+'%0.5d'%char_dict[tagcode_unicode]ifnot os.path.exists(dir_name):
os.mkdir(dir_name)
im.convert('RGB').save(dir_name+'/'+ str(train_counter)+'.png')
train_counter +=1for image, tagcode in read_from_gnt_dir(gnt_dir=test_data_dir):
tagcode_unicode =struct.pack('>H', tagcode).decode('gb2312')
im =Image.fromarray(image)
dir_name ='../data/test/'+'%0.5d'%char_dict[tagcode_unicode]ifnot os.path.exists(dir_name):
os.mkdir(dir_name)
im.convert('RGB').save(dir_name+'/'+ str(test_counter)+'.png')
test_counter +=1
處理好的資料,放到了雲盤,大家可以直接在我的雲盤來下載處理好的資料集HWDB1. 這裡說明下,char_dict是漢字和對應的數字label的記錄。
得到資料集後,就要考慮如何讀取了,一次用numpy讀入記憶體在很多小資料集上是可以行的,但是在稍微大點的資料集上記憶體就成了瓶頸,但是不要害怕,TensorFlow有自己的方法:
def batch_data(file_labels,sess, batch_size=128):
image_list =[file_label[0]for file_label in file_labels]
label_list =[int(file_label[1])for file_label in file_labels]print'tag2 {0}'.format(len(image_list))
images_tensor = tf.convert_to_tensor(image_list, dtype=tf.string)
labels_tensor = tf.convert_to_tensor(label_list, dtype=tf.int64)
input_queue = tf.train.slice_input_producer([images_tensor, labels_tensor])
labels = input_queue[1]
images_content = tf.read_file(input_queue[0])# images = tf.image.decode_png(images_content, channels=1)
images = tf.image.convert_image_dtype(tf.image.decode_png(images_content, channels=1), tf.float32
相關推薦
TensorFlow與中文手寫漢字識別
Goal 本文目標是利用TensorFlow做一個簡單的影象分類器,在比較大的資料集上,儘可能高效地做影象相關處理,從Train,Validation到Inference,是一個比較基本的Example, 從一個基本的任務學習如果在TensorFlow下
聯機與離線 手寫漢字識別
1. 知識瞭解 1.1 漢字識別的兩類主流方法 Online recognition: 聯機識別,基於筆畫軌跡 Offline recognition: 離線識別, 基於影象 ( 聯機手寫漢字識別所處理的手寫文字是書寫者通過物理裝置 (如數字筆、 數字手寫板或者觸控式螢
TensorFlow手寫漢字識別
MNIST手寫數字資料集通常做為深度學習的練習資料集,這個資料集恐怕早已經被大家玩壞了。本帖就介紹一個和MNIST類似,同時又適合國人練習的資料集-手寫漢字資料集,然後訓練一個簡單的Deep Convolutional Network識別手寫漢字。 識別
OCR(離線手寫漢字識別與印刷漢字識別)
4 “最後的堡壘”——離線手寫漢字識別4.1 攻克堡壘待創新離線手寫漢字識別的用途是把手寫字元用字元閱讀器自動輸入計算機,常用於信函分揀、銀行支票識別和統計報表處理以及手寫文稿的自動輸入。從工作原理上說
Tensorflow實踐 mnist手寫數字識別
model 損失函數 兩層 最簡 sin test http gif bat minst數據集 tensorflow的文檔中就自帶了mnist手寫數字識別的例子,是一個很經典也比較簡單
Tensorflow之MNIST手寫數字識別:分類問題(1)
一、MNIST資料集讀取 one hot 獨熱編碼獨熱編碼是一種稀疏向量,其中:一個向量設為1,其他元素均設為0.獨熱編碼常用於表示擁有有限個可能值的字串或識別符號優點: 1、將離散特徵的取值擴充套件到了歐式空間,離散特徵的某個取值就對應歐式空間的某個點 2、機器學習演算法中,
Tensorflow之MNIST手寫數字識別:分類問題(2)
整體程式碼: #資料讀取 import tensorflow as tf import matplotlib.pyplot as plt import numpy as np from tensorflow.examples.tutorials.mnist import input_data mnis
Python(TensorFlow框架)實現手寫數字識別系統
手寫數字識別演算法的設計與實現 本文使用python基於TensorFlow設計手寫數字識別演算法,並程式設計實現GUI介面,構建手寫數字識別系統。這是本人的本科畢業論文課題,當然,這個也是機器學習的基本問題。本博文不會以論文的形式展現,而是以程式設計實戰
基於tensorflow的MNIST手寫數字識別(二)--入門篇
一、本文的意義 因為谷歌官方其實已經寫了MNIST入門和深入兩篇教程了,那我寫這些文章又是為什麼呢,只是抄襲?那倒並不是,更準確的說應該是筆記吧,然後用更通俗的語言來解釋,並且補充
【Tensorflow入門】手寫字型識別——卷積神經網路
慣例放結果,瞬間識別率就上99.29%了…… import input_data mnist = input_data.read_data_sets('MNIST_data', one_hot=True) import tensorflow as tf sess =
Android+TensorFlow+CNN+MNIST 手寫數字識別實現
SkySeraph 2018 Overview 本文系“SkySeraph AI 實踐到理論系列”第一篇,咱以AI界的HelloWord 經典MNIST資料集為基礎,在Android平臺,基於TensorFlow,實現CNN的手寫數字識別。Code here~ Practice Env
【Tensorflow入門】手寫字型識別(MNIST)
轉載自: 地址 配置有困難的話可以直接下載: 地址 //當然照著這個教程配置很輕鬆的其實,完全可以不用浪費這1積分,攤手… MNIST機器學習入門 這個教程的目標讀者是對機器學習和TensorFlow都不太瞭解的新手。如果你已經瞭解MNIST和softmax
keras+卷積神經網路HWDB手寫漢字識別
寫在前面 HWDB手寫漢字資料集來自於中科院自動化研究所,下載地址: 原始碼 按照github上的提示操作: (1)解壓 unzip HWDB1.1trn_gnt.zip Archive: HWDB1.1trn_gnt.zip inflating:
基於tensorflow的MNIST手寫數字識別(三)--神經網路篇
想想還是要說點什麼 抱歉啊,第三篇姍姍來遲,確實是因為我懶,而不是忙什麼的,所以這次再加點料,以表示我的歉意。廢話不多說,我就直接開始講了。 加入神經網路的意義 前面也講到了,使用普通的訓練方法,也可以進行識別,但是識別的精度不夠高,
Tensorflow手寫數字識別之簡單神經網路分類與CNN分類效果對比
用Tensorflow進行深度學習和人工智慧具有開發簡單,建模速度快,準確度高的優點。作為學習影象識別分類的入門,手寫輸入數字識別是個很好的例子。 MNIST包中共有60000個手寫數字筆跡灰度影象作為訓練集,每張手寫數字筆跡圖片均已儲存為28*28畫素,同時還有一個la
Tensorflow - Tutorial (7) : 利用 RNN/LSTM 進行手寫數字識別
ddc htm net sets 手寫 n-2 align csdn global 1. 經常使用類 class tf.contrib.rnn.BasicLSTMCell BasicLSTMCell 是最簡單的一個LSTM類。沒有實現clippi
tensorflow 基礎學習五:MNIST手寫數字識別
truncate averages val flow one die correct 表示 data MNIST數據集介紹: from tensorflow.examples.tutorials.mnist import input_data # 載入MNIST數據集,
第二節,TensorFlow 使用前饋神經網絡實現手寫數字識別
com net config return pyplot dataset 運行 算法 但是 一 感知器 感知器學習筆記:https://blog.csdn.net/liyuanbhu/article/details/51622695 感知器(Percep
第三節,TensorFlow 使用CNN實現手寫數字識別
啟用 out min 灰度 HA 打破 gre 大量 gray 上一節,我們已經講解了使用全連接網絡實現手寫數字識別,其正確率大概能達到98%,著一節我們使用卷積神經網絡來實現手寫數字識別, 其準確率可以超過99%,程序主要包括以下幾塊內容 [1]: 導入數據,即測試集和
TensorFlow(九):卷積神經網絡實現手寫數字識別以及可視化
writer orm true 交叉 lar write 執行 one 界面 上代碼: import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist =