1. 程式人生 > tfrecord數據集訓練驗證-貓狗大戰

tfrecord數據集訓練驗證-貓狗大戰

關鍵字:圖片大小、cat、rac、exc、兩個、bin、span、loss、error。程式碼如下:

#!/usr/bin/env python
# -*- coding:utf-8 -*-

# The star imports stay first so that the explicit imports below win if the
# project modules happen to export the same names.
from mk_tfrecord import *
#from model import *
from inception_v3 import *
import os
import time

import cv2
import numpy as np

os.environ["CUDA_VISIBLE_DEVICES"] = "2"

def training():
    """Train Inception-v3 on the cat/dog TFRecord set and save checkpoints.

    Builds the input pipeline, model, loss, accuracy and optimizer ops, then
    runs up to MAX_STEP optimisation steps. Loss and accuracy are printed and
    summaries written every 100 steps; a checkpoint is saved every 1000 steps
    and on the final step. Training stops early if the input queue is
    exhausted (the filename queue is limited to 150 epochs).

    `tf`, `read_and_decode`, `inception_v3`, `loss`, `accuracy` and
    `optimize` are not defined in this file; they are expected to come from
    the star imports of `mk_tfrecord` / `inception_v3`.
    """
    N_CLASSES = 2              # number of classes
    IMG_W = 299                # uniform image width
    IMG_H = 299                # uniform image height
    BATCH_SIZE = 64            # batch size
    MAX_STEP = 50000           # maximum number of training iterations
    LEARNING_RATE = 0.0001     # learning rate
    min_after_dequeue = 1000
    tfrecord_filename = '/home/xieqi/project/cat_dog/train.tfrecords'  # training set
    logs_dir = '/home/xieqi/project/cat_dog/log_v3'                    # checkpoint directory

    # Queue of input file names. With num_epochs=150 the queue raises
    # OutOfRangeError once the file has been served 150 times.
    filename_queue = tf.train.string_input_producer([tfrecord_filename], num_epochs=150)
    # Both return values are tensors.
    train_image, train_label = read_and_decode(filename_queue, image_W=IMG_W, image_H=IMG_H,
                                               batch_size=BATCH_SIZE,
                                               min_after_dequeue=min_after_dequeue)
    train_labels = tf.one_hot(train_label, N_CLASSES)
    train_logits, _ = inception_v3(train_image, num_classes=N_CLASSES)
    train_loss = loss(train_logits, train_labels)       # loss op
    train_acc = accuracy(train_logits, train_labels)    # accuracy op
    my_global_step = tf.Variable(0, name='global_step', trainable=False)
    train_op = optimize(train_loss, LEARNING_RATE, my_global_step)
    # merge_all() returns None when no summaries were registered.
    summary_op = tf.summary.merge_all()
    # Local variables must be initialised too: the epoch counter created by
    # num_epochs is a local variable.
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

    # To cap GPU memory use instead:
    #   sess_config = tf.ConfigProto()
    #   sess_config.gpu_options.per_process_gpu_memory_fraction = 0.70
    #   sess = tf.Session(config=sess_config)
    sess = tf.Session()
    train_writer = tf.summary.FileWriter(logs_dir, sess.graph)
    saver = tf.train.Saver()
    sess.run(init_op)

    coord = tf.train.Coordinator()
    # Start the queue-runner threads that fill the input queues.
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # Fetches used on logging steps. Everything is evaluated in ONE run call
    # so the reported loss/accuracy belong to the batch that was trained on;
    # every separate sess.run() would pull a fresh batch from the input queue.
    log_fetches = {'train': train_op, 'loss': train_loss, 'acc': train_acc}
    if summary_op is not None:
        log_fetches['summary'] = summary_op

    try:
        for step in range(MAX_STEP):
            if coord.should_stop():
                break

            if step % 100 == 0:
                # Every 100 steps also report loss/accuracy and write summaries.
                results = sess.run(log_fetches)
                print('Step: %6d, loss: %.8f, accuracy: %.2f%%' % (step, results['loss'], results['acc']))
                if 'summary' in results:
                    train_writer.add_summary(results['summary'], step)
            else:
                sess.run(train_op)

            if step % 1000 == 0 or step == MAX_STEP - 1:
                # Save a checkpoint.
                checkpoint_path = os.path.join(logs_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)
    except tf.errors.OutOfRangeError:
        # Input queue exhausted before MAX_STEP was reached.
        print('Done.')
    finally:
        coord.request_stop()
        coord.join(threads=threads)
        train_writer.close()
        sess.close()


# Evaluate a saved checkpoint.
# NOTE: the name shadows the builtin eval(); kept so existing callers still work.
def eval():
    """Evaluate the latest checkpoint on the validation TFRecord set.

    Restores the newest checkpoint from logs_dir, classifies up to MAX_STEP
    validation batches, writes every misclassified image to false_pic_dir and
    prints totals, accuracy and the mean time per sess.run call. Returns
    early, without evaluating, when no checkpoint is found.
    """
    N_CLASSES = 2
    IMG_W = 299
    IMG_H = 299
    BATCH_SIZE = 1
    MAX_STEP = 512
    min_after_dequeue = 0
    test_dir = '/home/xieqi/project/cat_dog/val.tfrecords'        # validation set
    logs_dir = '/home/xieqi/project/cat_dog/log_v3'               # checkpoint directory
    false_pic_dir = '/home/xieqi/project/cat_dog/false_pic/'      # where misclassified images go

    # One pass over the validation file; OutOfRangeError is raised afterwards.
    filename_queue = tf.train.string_input_producer([test_dir], num_epochs=1)
    # Both return values are tensors.
    train_image, train_label = read_and_decode(filename_queue, image_W=IMG_W, image_H=IMG_H,
                                               batch_size=BATCH_SIZE,
                                               min_after_dequeue=min_after_dequeue)
    train_labels = tf.one_hot(train_label, N_CLASSES)
    # NOTE(review): if inception_v3 takes an is_training flag (as the slim
    # implementation does), it should probably be False here - confirm.
    train_logits, _ = inception_v3(train_image, N_CLASSES)
    train_logits = tf.nn.softmax(train_logits)    # convert logits to probabilities

    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    sess = tf.Session()
    sess.run(init_op)

    # Load the checkpoint.
    saver = tf.train.Saver()
    print('\n載入檢查點...')
    # The checkpoint state exposes model_checkpoint_path (newest model file)
    # and all_model_checkpoint_paths (all model files not yet deleted).
    ckpt = tf.train.get_checkpoint_state(logs_dir)
    if ckpt and ckpt.model_checkpoint_path:
        global_step = int(ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
        saver.restore(sess, ckpt.model_checkpoint_path)
        print('載入成功,global_step = %d\n' % global_step)
    else:
        print('沒有找到檢查點')
        # Evaluating freshly initialised weights would be meaningless.
        sess.close()
        return

    # cv2.imwrite reports failure only through its return value, so make
    # sure the output directory exists.
    if not os.path.isdir(false_pic_dir):
        os.makedirs(false_pic_dir)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    correct = 0
    wrong = 0
    dt_list = []
    try:
        for step in range(MAX_STEP):
            if coord.should_stop():
                break

            st = time.time()
            image, prediction, labels = sess.run([train_image, train_logits, train_labels])
            dt_list.append(time.time() - st)

            # argmax over the flattened arrays: this comparison is only
            # meaningful for BATCH_SIZE == 1.
            p_max_index = np.argmax(prediction)
            c_max_index = np.argmax(labels)
            if p_max_index == c_max_index:
                correct += BATCH_SIZE
            else:
                for i in range(BATCH_SIZE):
                    wrong += 1
                    # NOTE(review): assumes image[i] is in a dtype/range and
                    # channel order that cv2.imwrite can store - confirm
                    # against read_and_decode.
                    cv2.imwrite(false_pic_dir + 'ture' + str(labels) + 'predict' + str(prediction) + '.jpg',
                                image[i])
    except tf.errors.OutOfRangeError:
        # Validation set exhausted before MAX_STEP was reached.
        print('OutOfRange')
    finally:
        coord.request_stop()
        coord.join(threads=threads)
        sess.close()

    # Report on the samples actually evaluated, which may be fewer than
    # MAX_STEP if the queue ran dry.
    total = correct + wrong
    if total:
        accuray_rate = float(correct) / total
        velocity = np.mean(dt_list)
        print('Total: %5d, correct: %5d, wrong: %5d, accuracy: %3.2f%%, each speed: %.4fs'
              % (total, correct, wrong, accuray_rate * 100, velocity))


if __name__ == '__main__':
    training()
    #eval()

tfrecord數據集訓練驗證-貓狗大戰