Mini project: training on a painting-style dataset
阿新 • Published: 2018-12-16
Continuing from the previous post, where the data was preprocessed, we now define the network.
Declare the inputs (the graph's "formal parameters") with tf.placeholder
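import tensorflow as tf  # TensorFlow 1.x graph-mode API

# A batch of 32x32 RGB images, their integer class labels,
# and the dropout rate, which is fed at run time
datas_placeholder = tf.placeholder(tf.float32, [None, 32, 32, 3])
labels_placeholder = tf.placeholder(tf.int32, [None])
dropout_placeholder = tf.placeholder(tf.float32)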
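dropout_placeholder holds the fraction of units to drop. This post feeds 0.1 during training and 0 at test time. The plain-Python sketch below shows what such a rate means under inverted dropout, which is the scheme TensorFlow's dropout follows. Kept activations are scaled by 1/(1 - rate) so the expected value stays the same, and a rate of 0 leaves the input untouched. The function name inverted_dropout is hypothetical and is only meant to illustrate the idea.

```python
import random
from typing import List, Optional


def inverted_dropout(xs: List[float], rate: float,
                     rng: Optional[random.Random] = None) -> List[float]:
    """Zero each element with probability `rate`; scale survivors by 1/(1 - rate)."""
    if not 0.0 <= rate < 1.0:
        raise ValueError("rate must be in [0, 1)")
    if rate == 0.0:
        # Test-time setting in this post: dropout is a no-op
        return list(xs)
    rng = rng or random.Random()
    keep_prob = 1.0 - rate
    # Keep each element with probability keep_prob, rescaling it so E[output] == input
    return [x / keep_prob if rng.random() < keep_prob else 0.0 for x in xs]


if __name__ == "__main__":
    acts = [1.0] * 10
    print(inverted_dropout(acts, 0.0))                    # unchanged
    print(inverted_dropout(acts, 0.1, random.Random(0)))  # each element is 0.0 or 1/0.9
```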
Define the network architecture with tf.layers
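# Two conv + max-pool stages, then a fully connected layer with dropout.
# tf.layers.conv2d defaults to padding="valid" and stride 1.
conv0 = tf.layers.conv2d(datas_placeholder, 20, 5, activation=tf.nn.relu)  # 20 filters, 5x5
pool0 = tf.layers.max_pooling2d(conv0, [2, 2], [2, 2])                    # 2x2 window, stride 2
conv1 = tf.layers.conv2d(pool0, 40, 4, activation=tf.nn.relu)             # 40 filters, 4x4
pool1 = tf.layers.max_pooling2d(conv1, [2, 2], [2, 2])
flatten = tf.layers.flatten(pool1)
fc = tf.layers.dense(flatten, 400, activation=tf.nn.relu)
# tf.layers.dropout only drops units when training=True (the default is False),
# so dropout is switched on here and the rate comes from dropout_placeholder:
# 0.1 while training, 0 at test time (where dropout becomes a no-op).
dropout_fc = tf.layers.dropout(fc, rate=dropout_placeholder, training=True)
logits = tf.layers.dense(dropout_fc, num_classes)  # num_classes is defined in the previous post
predicted_labels = tf.argmax(logits, 1)            # index of the highest-scoring class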
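The sketch below traces the tensor shape and parameter count of each layer above in plain Python, so you can check what the network produces. It assumes the tf.layers defaults (padding="valid", stride 1 for the convolutions). It also assumes num_classes = 3, which matches the three-entry label dictionary used later in this post. The helper names valid_out and trace_network are hypothetical.

```python
from typing import List, Tuple


def valid_out(size: int, window: int, stride: int = 1) -> int:
    """Output size of a 'valid' (unpadded) convolution or pooling window."""
    return (size - window) // stride + 1


def trace_network(num_classes: int = 3) -> Tuple[List[Tuple[str, tuple, int]], int]:
    """Return (layer name, output shape, parameter count) rows and the total."""
    rows = []
    h, w, c = 32, 32, 3  # one 32x32 RGB image (batch dimension omitted)

    # conv0: 20 filters of shape 5x5xc, plus one bias per filter
    params = 5 * 5 * c * 20 + 20
    h, w, c = valid_out(h, 5), valid_out(w, 5), 20
    rows.append(("conv0", (h, w, c), params))

    # pool0: 2x2 max pooling with stride 2 has no weights
    h, w = valid_out(h, 2, 2), valid_out(w, 2, 2)
    rows.append(("pool0", (h, w, c), 0))

    # conv1: 40 filters of shape 4x4xc, plus biases
    params = 4 * 4 * c * 40 + 40
    h, w, c = valid_out(h, 4), valid_out(w, 4), 40
    rows.append(("conv1", (h, w, c), params))

    # pool1: 2x2 max pooling with stride 2 (11 -> 5, the last row/column is dropped)
    h, w = valid_out(h, 2, 2), valid_out(w, 2, 2)
    rows.append(("pool1", (h, w, c), 0))

    flat = h * w * c
    rows.append(("flatten", (flat,), 0))
    rows.append(("fc", (400,), flat * 400 + 400))
    # dropout adds no weights and keeps the (400,) shape
    rows.append(("logits", (num_classes,), 400 * num_classes + num_classes))

    return rows, sum(p for _, _, p in rows)


if __name__ == "__main__":
    rows, total = trace_network()
    for name, shape, params in rows:
        print("{:8s} {!s:15s} {:>8d}".format(name, shape, params))
    print("total parameters:", total)
```

Under these assumptions, the dense layer on the 1000-value flattened vector accounts for 400,400 of the 415,963 trainable parameters.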
Define the loss function and optimizer
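# Per-example softmax cross-entropy against one-hot labels
losses = tf.nn.softmax_cross_entropy_with_logits_v2(
    labels=tf.one_hot(labels_placeholder, num_classes),
    logits=logits
)
mean_loss = tf.reduce_mean(losses)
# Minimize the scalar batch-mean loss, the same value that is logged during training
optimizer = tf.train.AdamOptimizer(learning_rate=1e-2).minimize(mean_loss)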
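For each example, softmax_cross_entropy_with_logits_v2 computes -Σ_k y_k log softmax(z)_k. When y is one-hot, only the label term survives, so the loss is simply -log p_label. The plain-Python sketch below shows three things: that per-example loss, the batch mean passed to the optimizer, and the row-wise argmax that tf.argmax(logits, 1) produces for predicted_labels. The function names and the example logits are made up for illustration.

```python
import math
from typing import List, Sequence


def log_softmax(z: Sequence[float]) -> List[float]:
    """log(softmax(z)), computed stably by subtracting the max logit."""
    m = max(z)
    lse = m + math.log(sum(math.exp(v - m) for v in z))
    return [v - lse for v in z]


def one_hot(label: int, num_classes: int) -> List[float]:
    return [1.0 if k == label else 0.0 for k in range(num_classes)]


def softmax_xent(logits: Sequence[float], label: int, num_classes: int) -> float:
    """Cross-entropy -sum_k y_k * log p_k with a one-hot target y."""
    y = one_hot(label, num_classes)
    return -sum(yk * lp for yk, lp in zip(y, log_softmax(logits)))


def argmax(z: Sequence[float]) -> int:
    """Index of the largest value (the first one on ties)."""
    return max(range(len(z)), key=lambda k: z[k])


if __name__ == "__main__":
    # Made-up logits for a batch of two images and three classes
    batch_logits = [[2.0, 1.0, 0.1], [0.2, 0.3, 2.5]]
    batch_labels = [0, 1]
    losses = [softmax_xent(z, y, 3) for z, y in zip(batch_logits, batch_labels)]
    mean_loss = sum(losses) / len(losses)
    print("losses:", losses)
    print("mean loss:", mean_loss)
    print("predicted:", [argmax(z) for z in batch_logits])  # [0, 2]
```

The second example is misclassified (label 1, prediction 2), so its loss is much larger than the first one's, and that pulls the mean up.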
Saving the model + (training or testing)
saver = tf.train.Saver()

# train, datas, labels, fpaths and model_path come from the previous post
with tf.Session() as sess:
    if train:
        print("Training")
        # Initialize all variables before the first training step
        sess.run(tf.global_variables_initializer())
        # Full-batch training: every step feeds the whole dataset
        train_feed_dict = {
            datas_placeholder: datas,
            labels_placeholder: labels,
            dropout_placeholder: 0.1  # drop 10% of the fc units
        }
        for step in range(200):
            _, mean_loss_val = sess.run([optimizer, mean_loss], feed_dict=train_feed_dict)
            if step % 10 == 0:
                print("step = {}\tmean loss = {}".format(step, mean_loss_val))
        saver.save(sess, model_path)
        print("Training finished, model saved to {}".format(model_path))
    else:
        print("Testing")
        # Restore the trained weights instead of initializing new ones
        saver.restore(sess, model_path)
        print("Model loaded from {}".format(model_path))
        label_name_dict = {
            0: "doodle",
            1: "oil painting",
            2: "sketch"
        }
        test_feed_dict = {
            datas_placeholder: datas,
            labels_placeholder: labels,
            dropout_placeholder: 0.0  # no dropout at test time
        }
        predicted_labels_val = sess.run(predicted_labels, feed_dict=test_feed_dict)
        for fpath, real_label, predicted_label in zip(fpaths, labels, predicted_labels_val):
            real_label_name = label_name_dict[real_label]
            predicted_label_name = label_name_dict[predicted_label]
            print("{}\t{} => {}".format(fpath, real_label_name, predicted_label_name))
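The test branch prints one line per image. To summarize a run, you could also compute overall accuracy and a confusion count from the same (real, predicted) label pairs. Here is a minimal plain-Python sketch. It assumes integer labels 0 to 2, as in label_name_dict. The function summarize, the LABEL_NAMES mapping and the example labels are all hypothetical.

```python
from collections import Counter
from typing import Iterable, Tuple

# Hypothetical copy of the post's label_name_dict
LABEL_NAMES = {0: "doodle", 1: "oil painting", 2: "sketch"}


def summarize(real: Iterable[int], predicted: Iterable[int]) -> Tuple[float, Counter]:
    """Return accuracy and a Counter mapping (real, predicted) -> count."""
    pairs = list(zip(real, predicted))
    if not pairs:
        raise ValueError("no predictions to summarize")
    correct = sum(1 for r, p in pairs if r == p)
    return correct / len(pairs), Counter(pairs)


if __name__ == "__main__":
    real = [0, 0, 1, 1, 2, 2]  # made-up labels
    pred = [0, 1, 1, 1, 2, 0]
    acc, confusion = summarize(real, pred)
    print("accuracy = {:.2%}".format(acc))
    for (r, p), n in sorted(confusion.items()):
        print("{} => {}: {}".format(LABEL_NAMES[r], LABEL_NAMES[p], n))
```

In the test branch above, this could be called as summarize(labels, predicted_labels_val).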
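Training results
The results look decent. With that, the painting-style recognition mini project is basically complete (cue the confetti).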