Tensorflow例項2:將影象和標籤資料(*.csv)轉化成tfrecords檔案,以便後續直接讀取tfrecords檔案進行圖片驗證碼識別訓練
阿新 • • 發佈:2018-12-20
由於多張影象和標籤值不在一起,現在此方法是把captcha_dir = "../data/GenPics/"
此路徑下的圖片與此路徑下的.csv
檔案合併起來,通過writer = tf.python_io.TFRecordWriter(path="./data/captcha.tfrecords")
將資料以tfrecords格式寫入到本地中,為了以後進行驗證碼圖片訓練做好準備。
具體操作步驟如下:
# -*- coding=utf-8 -*- import tensorflow as tf import os os.environ["TF_CPP_MIN_LOG_LEVEL"] = '2' # 只顯示 warning 和 Error FLAGS = tf.app.flags.FLAGS tf.app.flags.DEFINE_string("tfrecords_dir", "./data/captcha.tfrecords", "驗證碼tfrecords檔案") tf.app.flags.DEFINE_string("captcha_dir", "../data/Genpics/", "驗證碼圖片路徑") tf.app.flags.DEFINE_string("letter", "ABCDEFGHIJKLMNOPQRSTUVWXYZ", "驗證碼字元的種類") def get_captcha_image(captcha_dir): """ 獲取驗證碼圖片資料 :param captcha_dir: 驗證碼圖片路徑 :return: image """ # 構造檔名 filename = [] for i in range(6000): string = str(i) + ".jpg" filename.append(string) # 構造路徑+檔案 # file_list = [os.path.join(FLAGS.captcha_dir, file) for file in filename] file_list = [os.path.join(captcha_dir, file) for file in filename] # 構造檔案佇列 file_queue = tf.train.string_input_producer(file_list, shuffle=False) # 構造閱讀器 reader = tf.WholeFileReader() # 讀取圖片資料內容 key, value = reader.read(file_queue) # 解碼圖片資料 image = tf.image.decode_jpeg(value) image.set_shape([20, 80, 3]) # 批量處理資料 [6000, 20, 80,3] image_batch = tf.train.batch([image], batch_size=6000, num_threads=1, capacity=6000) return image_batch def get_captcha_label(captcha_dir): """ 讀取驗證碼圖片標籤資料 :param captcha_dir: 驗證碼標籤路徑 :return: label """ # 構造標籤資料檔案路徑 captcha_dir = captcha_dir + "labels.csv" # 構造檔案佇列 file_queue = tf.train.string_input_producer([captcha_dir], shuffle=False) # 構造閱讀器 reader = tf.TextLineReader() # 讀取excel的label資料內容 key, value = reader.read(file_queue) # 解碼csv資料 # records:指定矩陣格式以及資料型別 # [1]中的1 用於指定資料型別,比如矩陣中如果有小數,則為float,[1]應該變為[1.0]。 records = [[1], ["None"]] number, label = tf.decode_csv(value, record_defaults=records) # 批處理資料 label_batch = tf.train.batch([label], batch_size=6000, num_threads=1, capacity=6000) return label_batch def dealwuthlabel(label_str): """ :param label_str: :return: """ # 驗證碼字串的種類 letter = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" # 構建字元索引 {0:'A', 1:'B', ...} num_letter = dict(enumerate(list(letter))) # 鍵值對反轉{'A':0, 'B':1, ...} letter_num = dict(zip(num_letter.values(), num_letter.keys())) print(letter_num) # 構建標籤到列表 array = [] # 給標籤資料進行處理 [[b'NZPP'], ...] for string in label_str: letter_list = [] # [1, 2, 3, 4] # 修改編碼,b'NZPP'到字串,並且迴圈找到每張驗證碼的字元對應的數字標記 for letter in string.decode("utf-8"): letter_list.append(letter_num[letter]) array.append(letter_list) # [[13, 25, 15, 15], [22, 10, 7, 10], [22, 15, 18, 9], ...] # print(array) # 將array轉換成Tensor型別 label = tf.constant(array) return label def write_to_tfrecords(image_batch, label_batch): """ 將圖片內容和標籤寫入到tfrecords檔案當中 :param image_batch: 特徵值 :param label_batch: 標籤值 :return: None """ # 轉換型別 label_batch = tf.cast(label_batch, tf.uint8) print(label_batch) # 建立TFRecords 儲存器 # writer = tf.python_io.TFRecordWriter(path=FLAGS.tfrecords_dir) writer = tf.python_io.TFRecordWriter(path="./data/captcha.tfrecords") # 迴圈將每一個圖片資料構造example協議快,序列化後寫入 for i in range(label_batch.shape[0]): # 取出第i個圖片資料,轉換為相應型別,圖片的特徵值要轉換為字串形式 image_string = image_batch[i].eval().tostring() # 標籤值,轉換成整型 label_string = label_batch[i].eval().tostring() # 構造協議塊 example = tf.train.Example(features=tf.train.Features(feature={ "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_string])), "label": tf.train.Feature(bytes_list=tf.train.BytesList(value=[label_string])), })) writer.write(example.SerializeToString()) # 關閉檔案 writer.close() return None if __name__ == '__main__': # 資料路徑 captcha_dir = "../data/GenPics/" # 獲取驗證碼檔案當中的圖片 image_batch = get_captcha_image(captcha_dir) # 獲取驗證碼檔案當中的標籤資料 label = get_captcha_label(captcha_dir) print(image_batch, label) with tf.Session() as sess: coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess, coord=coord) # [b'NZPP' b'WKHK' b'WPSJ' ... b'FVQJ' b'BQYA' b'BCHR'] label_str = sess.run(label) print(label_str) # 處理字串標籤 轉變為數字張量 label_batch = dealwuthlabel(label_str) # 將圖片資料和內容寫入到tfrecords檔案中 write_to_tfrecords(image_batch, label_batch) coord.request_stop() coord.join(threads)