TensorFlow: Training on Your Own Data (Part 4): Model Testing
阿新 • Published: 2019-01-06
Getting one image
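Function: get_one_image(train)
- Input parameter: train, the list of training image paths
- Return value: image, one training image picked at random, resized to 208x208 and converted to a NumPy array

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

def get_one_image(train):
    n = len(train)
    ind = np.random.randint(0, n)      # random index in [0, n)
    img_dir = train[ind]
    image = Image.open(img_dir)
    plt.imshow(image)                  # display the chosen image
    image = image.resize((208, 208))   # match the network's input size
    image = np.array(image)
    return image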
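The network later reshapes its input to [1, 208, 208, 3], so the array returned by get_one_image must have exactly three channels. np.array(image) on a grayscale or RGBA picture yields shape (208, 208) or (208, 208, 4), and that reshape would then fail. Below is a minimal NumPy sketch of a hypothetical helper, to_rgb_array (not part of the original code), that normalizes the channel count; with PIL itself, calling image.convert('RGB') before resizing has the same effect.

```python
import numpy as np


def to_rgb_array(arr):
    """Coerce an image array to H x W x 3 (hypothetical helper, not from the post)."""
    arr = np.asarray(arr)
    if arr.ndim == 2:
        # Grayscale (H x W): repeat the single channel three times
        arr = np.stack([arr, arr, arr], axis=-1)
    elif arr.ndim == 3 and arr.shape[-1] == 4:
        # RGBA (H x W x 4): drop the alpha channel
        arr = arr[..., :3]
    if arr.ndim != 3 or arr.shape[-1] != 3:
        raise ValueError("unsupported image shape: %r" % (arr.shape,))
    return arr


gray = np.zeros((208, 208), dtype=np.uint8)
rgba = np.zeros((208, 208, 4), dtype=np.uint8)
print(to_rgb_array(gray).shape)  # (208, 208, 3)
print(to_rgb_array(rgba).shape)  # (208, 208, 3)
```

An array that already has three channels passes through unchanged, so the helper is safe to apply to every image.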
Testing the image
Function: evaluate_one_image(image_array), where image_array is the 208x208x3 array returned by get_one_image. The image is fed in through the placeholder x, so feed_dict actually reaches the network.

import numpy as np
import tensorflow as tf
import model  # this project's model definition, which provides inference()

def evaluate_one_image(image_array):
    with tf.Graph().as_default():
        BATCH_SIZE = 1
        N_CLASSES = 2
        # Build the graph on the placeholder so the fed image is what gets classified
        x = tf.placeholder(tf.float32, shape=[208, 208, 3])
        image = tf.image.per_image_standardization(x)
        image = tf.reshape(image, [1, 208, 208, 3])
        logit = model.inference(image, BATCH_SIZE, N_CLASSES)
        logit = tf.nn.softmax(logit)

        # you need to change the directories to yours.
        logs_train_dir = 'D:/Study/Python/Projects/Cats_vs_Dogs/Logs/train'
        saver = tf.train.Saver()
        with tf.Session() as sess:
            print("Reading checkpoints...")
            ckpt = tf.train.get_checkpoint_state(logs_train_dir)
            if ckpt and ckpt.model_checkpoint_path:
                global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                saver.restore(sess, ckpt.model_checkpoint_path)
                print('Loading success, global_step is %s' % global_step)
            else:
                print('No checkpoint file found')
                return  # nothing to evaluate without trained weights
            prediction = sess.run(logit, feed_dict={x: image_array})
            max_index = np.argmax(prediction)
            if max_index == 0:
                print('This is a cat with probability %.6f' % prediction[0, 0])
            else:
                print('This is a dog with probability %.6f' % prediction[0, 1])
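To make the graph's arithmetic concrete, here is a NumPy sketch of the three steps applied to one image: standardization, softmax, and the argmax decision. standardize follows the documented formula of tf.image.per_image_standardization, (x - mean) / max(stddev, 1/sqrt(N)), with N the number of elements; softmax is the usual numerically stable form. The helper names and the logits in the usage lines are made up for illustration; they are not the model's real output.

```python
import numpy as np

CLASS_NAMES = ["cat", "dog"]  # index 0 = cat, index 1 = dog, as in evaluate_one_image


def standardize(image):
    """Mirror the documented formula of tf.image.per_image_standardization."""
    x = np.asarray(image, dtype=np.float64)
    # The stddev is floored at 1/sqrt(N) so a flat image does not divide by zero
    adjusted_std = max(x.std(), 1.0 / np.sqrt(x.size))
    return (x - x.mean()) / adjusted_std


def softmax(logits):
    """Numerically stable softmax over the last axis (what tf.nn.softmax computes)."""
    z = np.asarray(logits, dtype=np.float64)
    z = z - z.max(axis=-1, keepdims=True)  # shifting by the max leaves the result unchanged
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def describe(prediction):
    """Format a [1, 2] probability array like the message evaluate_one_image prints."""
    max_index = int(np.argmax(prediction))
    return "This is a %s with probability %.6f" % (CLASS_NAMES[max_index], prediction[0, max_index])


image = np.random.randint(0, 256, size=(208, 208, 3))
print(standardize(image).shape)  # (208, 208, 3)

prediction = softmax(np.array([[2.0, -1.0]]))  # made-up logits, not real model output
print(describe(prediction))  # This is a cat with probability 0.952574
```

Because softmax outputs sum to 1, the printed probability of the winning class is always at least 0.5 in this two-class setup.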
Testing on held-out images at intervals during training
When collecting the files, set aside 20% of the training images as test (validation) data.
Changes in the function get_files(file_dir, ratio):
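import math

def get_files(file_dir, ratio):
    # ... the earlier part of get_files (not shown in this post) builds `temp`,
    # an array whose column 0 holds image paths and column 1 holds labels ...

    # all image paths and all labels
    all_image_list = temp[:, 0]
    all_label_list = temp[:, 1]

    # Split the lists into two parts: one for training (tra), one for validation (val).
    # ratio is the fraction of samples used for validation.
    n_sample = len(all_label_list)
    n_val = math.ceil(n_sample * ratio)  # number of validation samples
    n_train = n_sample - n_val           # number of training samples

    tra_images = all_image_list[0:n_train]
    tra_labels = all_label_list[0:n_train]
    tra_labels = [int(float(i)) for i in tra_labels]  # convert string labels back to int
    val_images = all_image_list[n_train:]  # [n_train:-1] would drop the last sample
    val_labels = all_label_list[n_train:]
    val_labels = [int(float(i)) for i in val_labels]

    return tra_images, tra_labels, val_images, val_labels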
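The split arithmetic can be checked in isolation with plain lists. This is a minimal sketch with a hypothetical helper, split_train_val, that applies the same ceil-based sizing. Ending a slice at -1, as in [n_train:-1], excludes the final element, so the sketch slices to the end instead. A plain slice only gives a random split if the list was shuffled beforehand.

```python
import math


def split_train_val(images, labels, ratio):
    """Hold out ceil(n * ratio) samples for validation (hypothetical helper)."""
    n_sample = len(labels)
    n_val = math.ceil(n_sample * ratio)  # number of validation samples
    n_train = n_sample - n_val           # number of training samples
    tra_images, tra_labels = images[:n_train], labels[:n_train]
    # Slice to the end so the last sample is kept in the validation set
    val_images, val_labels = images[n_train:], labels[n_train:]
    return tra_images, tra_labels, val_images, val_labels


paths = ["img%d.jpg" % i for i in range(10)]
labels = [i % 2 for i in range(10)]
tra, tra_l, val, val_l = split_train_val(paths, labels, 0.2)
print(len(tra), len(val))  # 8 2
```

Rounding up with math.ceil guarantees at least one validation sample whenever ratio is positive, even for a small data set.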
Then make the following changes in the training code.
Get the training and validation batches:
train_batch, train_label_batch = input_train_val_split.get_batch(
    train, train_label, IMG_W, IMG_H, BATCH_SIZE, CAPACITY)
val_batch, val_label_batch = input_train_val_split.get_batch(
    val, val_label, IMG_W, IMG_H, BATCH_SIZE, CAPACITY)
Every 200 steps (and at the final step), evaluate one validation batch and write the result to the log:
if step % 200 == 0 or (step + 1) == MAX_STEP:
    val_images, val_labels = sess.run([val_batch, val_label_batch])
    val_loss, val_acc = sess.run([loss, acc],
                                 feed_dict={x: val_images, y_: val_labels})
    print('** Step %d, val loss = %.2f, val accuracy = %.2f%% **' % (step, val_loss, val_acc * 100.0))
    # Feed the same batch so summaries that depend on x / y_ can be evaluated
    summary_str = sess.run(summary_op, feed_dict={x: val_images, y_: val_labels})
    val_writer.add_summary(summary_str, step)
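The validation schedule is easy to misread, so here is a small sketch of which steps trigger it. It assumes the loop runs step = 0, 1, ..., MAX_STEP - 1; should_validate and the MAX_STEP value are hypothetical, chosen only for the illustration.

```python
def should_validate(step, max_step, interval=200):
    """Same condition as the training loop: every `interval` steps, plus the last step."""
    return step % interval == 0 or (step + 1) == max_step


MAX_STEP = 1000  # hypothetical value for illustration only
eval_steps = [s for s in range(MAX_STEP) if should_validate(s, MAX_STEP)]
print(eval_steps)  # [0, 200, 400, 600, 800, 999]
```

The (step + 1) == MAX_STEP clause makes sure the final model is validated even when MAX_STEP is not a multiple of 200; step 0 is also validated, before any meaningful training has happened.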
Results
The probability that this image is a cat is 0.987972; the model used was trained for 6000 steps.