1. 程式人生 > 關於TensorFlow批量資料的輸入

關於Tensorflow批量資料的輸入

關於 TensorFlow 下批量資料的輸入處理，本文介紹兩種方法：
1. 使用 TensorFlow 的 TFRecords 格式
2. 使用 h5py 庫以陣列形式儲存

在 TensorFlow 框架下撰寫 CNN 程式碼時，我感覺難點不在於框架本身，反而是影像的預處理與輸入這部分花了我很多心力。

使用了兩種方法:
方法一:
TensorFlow 以 TFRecords 格式儲存資料；寫入資料的同時，也可以為每筆資料打上標籤。
①建立 TFRecords 檔案

orig_image = '/home/images/train_image/'
gen_image = '/home/images/image_train.tfrecords'


def create_record(src_dir=None, dst_path=None, label=0):
    """Write every image in a directory into a single TFRecords file.

    Each image is stored as one ``tf.train.Example`` with two features:
    ``'label'`` (int64) and ``'img_raw'`` (the raw pixel bytes returned by
    ``PIL.Image.tobytes()``).

    Args:
        src_dir: directory holding the source images. Defaults to the
            module-level ``orig_image``, which is read at call time.
        dst_path: output ``.tfrecords`` path. Defaults to the module-level
            ``gen_image``, which is read at call time.
        label: integer label attached to every image. The original code
            hard-coded 0, which remains the default.
    """
    if src_dir is None:
        src_dir = orig_image
    if dst_path is None:
        dst_path = gen_image
    # The context manager closes the writer even if an image fails to load.
    # The original only closed it after the loop completed normally.
    with tf.python_io.TFRecordWriter(dst_path) as writer:
        # sorted() gives a deterministic record order; os.listdir's order is arbitrary.
        for img_name in sorted(os.listdir(src_dir)):
            img_path = os.path.join(src_dir, img_name)
            # Close each image file handle once its bytes have been extracted.
            with Image.open(img_path) as img:
                # Optional preprocessing goes here, e.g. img = img.resize((256, 256)).
                # NOTE(review): read_and_decode reshapes each record to
                # [256, 256, 1], so images are assumed to be 256x256
                # single-channel. Confirm for your data.
                img_raw = img.tobytes()  # raw pixel bytes
            example = tf.train.Example(features=tf.train.Features(feature={
                'label': tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[label])),
                'img_raw': tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[img_raw])),
            }))
            writer.write(example.SerializeToString())

②讀取 TFRecords 檔案

def read_and_decode(filename, shape=(256, 256, 1)):
    """Build a queue-based pipeline that yields one decoded example at a time.

    Args:
        filename: path of a ``.tfrecords`` file whose examples carry an int64
            ``'label'`` and a bytes ``'img_raw'`` feature.
        shape: shape each decoded image is reshaped to. It defaults to the
            original hard-coded ``[256, 256, 1]``. The number of stored bytes
            must equal the product of this shape, or the reshape fails at run time.

    Returns:
        ``(img, label)`` tensors. ``img`` is float32 with shape ``shape``;
        ``convert_image_dtype`` rescales the uint8 pixels to [0, 1].
        ``label`` is int32.
    """
    # Filename queue. With the default num_epochs=None it cycles through the
    # file an unlimited number of times.
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'img_raw': tf.FixedLenFeature([], tf.string),
        })
    label = features['label']
    img = features['img_raw']
    # Raw bytes -> uint8 vector -> float32 in [0, 1].
    img = tf.decode_raw(img, tf.uint8)
    img = tf.image.convert_image_dtype(img, dtype=tf.float32)
    img = tf.reshape(img, list(shape))
    label = tf.cast(label, tf.int32)
    return img, label

③批量讀取資料,使用tf.train.batch

# NOTE(review): batch_size was used here but never defined in this snippet
# (NameError). 20 matches the value used for method two; adjust as needed.
batch_size = 20
min_after_dequeue = 10000
# tf.train.batch has no min_after_dequeue argument (tf.train.shuffle_batch does).
# Here the value only sizes the queue capacity.
capacity = min_after_dequeue + 3 * batch_size
num_samples = len(os.listdir(orig_image))
create_record()
img, label = read_and_decode(gen_image)
# Number of full batches per pass. Floor division drops a partial tail batch.
total_batch = num_samples // batch_size
image_batch, label_batch = tf.train.batch([img, label], batch_size=batch_size,
                                          num_threads=32, capacity=capacity)
init_op = tf.group(tf.global_variables_initializer(),
                   tf.local_variables_initializer())
with tf.Session() as sess:
    sess.run(init_op)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        for _ in range(total_batch):
            cur_image_batch, cur_label_batch = sess.run([image_batch, label_batch])
    finally:
        # Stop and join the queue-runner threads even if sess.run raises,
        # so the session is not closed with runner threads still active.
        coord.request_stop()
        coord.join(threads)

方法二:
使用 h5py 就是以陣列的格式來儲存資料。
我認為這個方法比較好用：在 CNN 的流程中常常需要同時儲存多組資料，例如同一份資料經過兩種以上的變換後需要分類儲存，這時用 h5py 就很方便。

import os
import h5py
import matplotlib.pyplot as plt
import numpy as np
import random
from scipy.interpolate import griddata
from skimage import img_as_float
import matplotlib.pyplot as plt
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
class_path = '/home/awen/Juanjuan/Python Project/train_BSDS/test_gray_0_1/'

# io.imread is used below, but io was never imported.
from skimage import io

# Fresh accumulators. The original appended to lists that were never created.
# The name orig_image is already the TFRecords path string from method one,
# so distinct names are used to avoid an AttributeError.
orig_images = []
near_samples = []
line_samples = []
# sorted() gives a deterministic sample order; os.listdir's order is arbitrary.
for img_name in sorted(os.listdir(class_path)):
    img_path = os.path.join(class_path, img_name)
    img = io.imread(img_path)
    m1 = img_as_float(img)  # float image, e.g. uint8 input is scaled to [0, 1]
    m2, m3 = sample_inter1(m1)  # user-defined processing function, not shown here
    # NOTE(review): assumes the image and both samples each hold exactly
    # 256 * 256 values (256x256 grayscale). Confirm for your data.
    orig_images.append(m1.reshape([256, 256, 1]))
    near_samples.append(m2.reshape([256, 256, 1]))
    line_samples.append(m3.reshape([256, 256, 1]))

arrorig_image = np.asarray(orig_images)      # [?, 256, 256, 1]
arrlsample_near = np.asarray(near_samples)   # [?, 256, 256, 1]
arrlsample_line = np.asarray(line_samples)   # [?, 256, 256, 1]

save_path = '/home/awen/Juanjuan/Python Project/train_BSDS/test_sample/train.h5'


def make_data(path):
    """Write the three preprocessed image stacks to an HDF5 file.

    Stores the module-level arrays ``arrorig_image``, ``arrlsample_near`` and
    ``arrlsample_line`` as datasets named ``'orig_image'``, ``'sample_near'``
    and ``'sample_line'``. Mode ``'w'`` truncates any existing file at ``path``.

    Args:
        path: destination ``.h5`` file path.
    """
    # Bug fix: the original ignored ``path`` and always wrote to the global
    # ``save_path``. The caller in this file passes save_path, so its output is unchanged.
    with h5py.File(path, 'w') as hf:
        hf.create_dataset('orig_image', data=arrorig_image)
        hf.create_dataset('sample_near', data=arrlsample_near)
        hf.create_dataset('sample_line', data=arrlsample_line)

def read_data(path):
    """Load the three image stacks from an HDF5 file.

    Args:
        path: ``.h5`` file containing datasets ``'orig_image'``,
            ``'sample_near'`` and ``'sample_line'``.

    Returns:
        ``(orig_image, sample_near, sample_line)`` as in-memory NumPy arrays.

    Raises:
        KeyError: if one of the datasets is missing. The original used
            ``hf.get``, which returns ``None`` for a missing key, and
            ``np.array(None)`` silently produced a useless 0-d object array.
    """
    with h5py.File(path, 'r') as hf:
        # The dataset names must match exactly those used when the file was written.
        # ``[()]`` reads the whole dataset into memory before the file closes.
        orig_image = hf['orig_image'][()]
        sample_near = hf['sample_near'][()]
        sample_line = hf['sample_line'][()]
    return orig_image, sample_near, sample_line
make_data(save_path)
orig_image1, sample_near1, sample_line1 = read_data(save_path)
total_number = len(orig_image1)
batch_size = 20
# Bug fix: ``/`` yields a float in Python 3, and range() rejects floats.
# Floor division gives the number of full batches. Any trailing partial
# batch (total_number % batch_size samples) is skipped.
batch_index = total_number // batch_size
for i in range(batch_index):
    start, stop = i * batch_size, (i + 1) * batch_size
    batch_orig = orig_image1[start:stop]
    batch_sample_near = sample_near1[start:stop]
    batch_sample_line = sample_line1[start:stop]

使用 h5py 時，如果生成的檔案非常大，讀取資料時可能出現錯誤：IOError: unable to open file (bad object header version number)
這基本表示生成的檔案已無法使用；適當減少儲存的資料量即可解決。