Tensorflow學習筆記-輸入資料處理框架
阿新 • • 發佈:2019-02-18
對應的程式碼流程如下:
# 建立檔案列表,並通過檔案列表來建立檔案佇列。在呼叫輸入資料處理流程前,需要統一
# 所有的原始資料格式,並將它們儲存到TFRecord檔案中
# match_filenames_once 獲取符合正則表示式的所有檔案
files = tf.train.match_filenames_once('path/to/file-*-*')
# 將檔案列表生成檔案佇列
filename_queue = tf.train.string_input_producer(files,shuffle=True)
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
# image:儲存影象中的原始資料
# label該樣本所對應的標籤
# width,height,channel
features = tf.parse_single_example(serialized_example,features={
'image' : tf.FixedLenFeature([],tf.string),
'label': tf.FixedLenFeature([], tf.int64),
'width': tf.FixedLenFeature([], tf.int64),
'heigth': tf.FixedLenFeature([], tf.int64),
'channel': tf.FixedLenFeature([], tf.int64)
})
image, label = features['image'], features['label' ]
width, height = features['width'], features['height']
channel = features['channel']
# 將原始影象資料解析出畫素矩陣,並根據影象尺寸還原糖影象。
decode_image = tf.decode_raw(image)
decode_image.set_shape([width,height,channel])
# 神經網路的輸入大小
image_size = 299
# 對影象進行預處理操作,比對亮度、對比度、隨機裁剪等操作
distorted_image = propocess_train(decode_image,image_size,None)
# shuffle_batch中的引數
min_after_dequeue = 1000
batch_size = 100
capacity = min_after_dequeue + 3*batch_size
image_batch,label_batch = tf.train.shuffle_batch([distorted_image,label],
batch_size=batch_size,capacity=capacity,
min_after_dequeue=min_after_dequeue)
logit = inference(image_batch)
loss = cal_loss(logit,label_batch)
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)
with tf.Session() as sess:
# 變數初始化
tf.global_variables_initializer().run()
# 執行緒初始化和啟動
coord = tf.train.Coordinator()
theads = tf.train.start_queue_runners(sess=sess,coord=coord)
for i in range(STEPS):
sess.run(train_step)
# 停止所有執行緒
coord.request_stop()
coord.join(threads)