
Mini Project: Training on a Painting-Style Dataset

Picking up from the previous post: the data has been processed, so now we define the network.

Define the inputs (formal parameters) with tf.placeholder

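import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.layers, tf.Session)

# Batch of 32x32 RGB images; None leaves the batch size open
datas_placeholder = tf.placeholder(tf.float32, [None, 32, 32, 3])
# One integer class label per image
labels_placeholder = tf.placeholder(tf.int32, [None])
# Dropout rate, i.e. the fraction of units to drop (fed 0.1 in training, 0 in testing)
dropout_placeholder = tf.placeholder(tf.float32)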
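dropout_placeholder carries a drop rate, not a keep probability. The rate is the fraction of activations zeroed during training. To show why feeding 0 at test time is safe, here is a minimal pure-Python sketch of inverted dropout, the scheme TensorFlow's dropout uses. Surviving values are scaled by 1/(1 - rate), so the expected activation is unchanged. The helper inverted_dropout is hypothetical and not part of TensorFlow.

```python
import random


def inverted_dropout(values, rate, rng=random):
    """Zero each value with probability `rate`; scale survivors by 1 / (1 - rate).

    With rate == 0 nothing is dropped and the scale factor is 1,
    so the operation reduces to the identity (the test-time case).
    """
    if not 0.0 <= rate < 1.0:
        raise ValueError("rate must be in [0, 1)")
    keep_prob = 1.0 - rate
    # rng.random() is uniform in [0, 1), so a value survives with probability 1 - rate
    return [v / keep_prob if rng.random() >= rate else 0.0 for v in values]


print(inverted_dropout([1.0, 2.0, 3.0], 0.0))
print(inverted_dropout([1.0] * 10, 0.1, random.Random(0)))
```

With rate=0.0 every value passes through unchanged. With rate=0.1, roughly one value in ten becomes 0.0 and the rest are multiplied by 1/0.9.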

Define the network structure with tf.layers

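# num_classes (3 styles here) is assumed to be defined alongside the data from the previous post
conv0 = tf.layers.conv2d(datas_placeholder, 20, 5, activation=tf.nn.relu)  # 20 filters, 5x5 kernel
pool0 = tf.layers.max_pooling2d(conv0, [2, 2], [2, 2])                     # 2x2 max pooling, stride 2
conv1 = tf.layers.conv2d(pool0, 40, 4, activation=tf.nn.relu)              # 40 filters, 4x4 kernel
pool1 = tf.layers.max_pooling2d(conv1, [2, 2], [2, 2])
flatten = tf.layers.flatten(pool1)
fc = tf.layers.dense(flatten, 400, activation=tf.nn.relu)
# tf.layers.dropout defaults to training=False, which would make it a no-op;
# keep it switched on and control it through the fed rate (rate 0 = identity at test time)
dropout_fc = tf.layers.dropout(fc, dropout_placeholder, training=True)
logits = tf.layers.dense(dropout_fc, num_classes)
predicted_labels = tf.argmax(logits, 1)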
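tf.layers.conv2d and tf.layers.max_pooling2d both default to padding='valid', so each layer shrinks the feature map. The small pure-Python sketch below traces the per-image shapes through the network, ignoring the batch axis. It assumes the 32x32x3 input above and num_classes = 3, matching the three styles named in the test code. valid_out and trace_shapes are illustrative helpers, not TensorFlow functions.

```python
def valid_out(size, window, stride=1):
    """Spatial size after a sliding-window op with padding='valid'."""
    return (size - window) // stride + 1


def trace_shapes(side=32, num_classes=3):
    """Per-image output shape of each layer in the network above (batch axis omitted)."""
    shapes = {}
    s = valid_out(side, 5)        # conv0: 5x5 kernel, stride 1, 20 filters
    shapes["conv0"] = (s, s, 20)
    s = valid_out(s, 2, 2)        # pool0: 2x2 window, stride 2
    shapes["pool0"] = (s, s, 20)
    s = valid_out(s, 4)           # conv1: 4x4 kernel, stride 1, 40 filters
    shapes["conv1"] = (s, s, 40)
    s = valid_out(s, 2, 2)        # pool1: 2x2 window, stride 2
    shapes["pool1"] = (s, s, 40)
    shapes["flatten"] = (s * s * 40,)
    shapes["fc"] = (400,)
    shapes["logits"] = (num_classes,)
    return shapes


for name, shape in trace_shapes().items():
    print(name, shape)
```

The 32x32 input shrinks to 28, 14, 11 and then 5. The flatten layer therefore hands 5 * 5 * 40 = 1000 features to the 400-unit dense layer.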

Define the loss function and the optimizer

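# Per-example cross-entropy between the one-hot labels and softmax(logits)
losses = tf.nn.softmax_cross_entropy_with_logits_v2(
    labels=tf.one_hot(labels_placeholder, num_classes),
    logits=logits
)
mean_loss = tf.reduce_mean(losses)
# Minimize the scalar batch mean, the same value printed during training
optimizer = tf.train.AdamOptimizer(learning_rate=1e-2).minimize(mean_loss)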
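For each example, tf.nn.softmax_cross_entropy_with_logits_v2 computes the cross-entropy between the one-hot label and the softmax of the logits. tf.reduce_mean then averages over the batch. The pure-Python sketch below reproduces that arithmetic, using the log-sum-exp trick for numerical stability. The helper names are hypothetical and chosen only for illustration.

```python
import math


def one_hot(label, num_classes):
    """Integer label -> one-hot list, like tf.one_hot for a single label."""
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


def softmax_cross_entropy(labels, logits):
    """-sum(labels * log_softmax(logits)), computed stably via log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return -sum(y * (z - log_sum_exp) for y, z in zip(labels, logits))


def batch_mean_loss(batch_labels, batch_logits, num_classes):
    """Average per-example loss over the batch (the reduce_mean step)."""
    losses = [softmax_cross_entropy(one_hot(y, num_classes), z)
              for y, z in zip(batch_labels, batch_logits)]
    return sum(losses) / len(losses)


print(batch_mean_loss([0, 2], [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]], 3))
```

When all logits are equal across three classes, the model is maximally unsure and the loss is ln 3 (about 1.0986). That value is a handy reference point when reading the printed mean loss of a 3-class model.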

Saving the model, plus training or testing

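saver = tf.train.Saver()
# train, model_path, datas, labels and fpaths are assumed to be set by the data-loading code from the previous post
with tf.Session() as sess:
    if train:
        print("Training")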
        sess.run(tf.global_variables_initializer())
        train_feed_dict = {
            datas_placeholder: datas,
            labels_placeholder: labels,
            dropout_placeholder: 0.1
        }
        for step in range(200):
            _, mean_loss_val = sess.run([optimizer, mean_loss], feed_dict=train_feed_dict)
            if step % 10 == 0:
                print("step = {}\tmean loss = {}".format(step, mean_loss_val))
        saver.save(sess, model_path)
        print("Training finished, model saved to {}".format(model_path))
    else:
        print("Testing")
        saver.restore(sess, model_path)
        print("Model loaded from {}".format(model_path))
        label_name_dict = {
            0: "graffiti",
            1: "oil painting",
            2: "sketch"
        }
        test_feed_dict = {
            datas_placeholder: datas,
            labels_placeholder: labels,
            dropout_placeholder: 0
        }
        predicted_labels_val = sess.run(predicted_labels, feed_dict=test_feed_dict)
        for fpath, real_label, predicted_label in zip(fpaths, labels, predicted_labels_val):
            real_label_name = label_name_dict[real_label]
            predicted_label_name = label_name_dict[predicted_label]
            print("{}\t{} => {}".format(fpath, real_label_name, predicted_label_name))
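The test branch prints one line per image. Collapsing those lines into a single accuracy figure makes runs easier to compare. The sketch below does this in plain Python. It assumes the logits have been fetched from the session and converted to nested lists. argmax and summarize are hypothetical helpers, and the English label names mirror the dictionary above.

```python
LABEL_NAMES = {0: "graffiti", 1: "oil painting", 2: "sketch"}


def argmax(row):
    """Index of the largest value in one row of logits (first index on ties)."""
    return max(range(len(row)), key=lambda i: row[i])


def summarize(fpaths, real_labels, logits_rows):
    """Per-image 'path<TAB>real => predicted' lines plus overall accuracy."""
    predicted = [argmax(row) for row in logits_rows]
    correct = sum(p == r for p, r in zip(predicted, real_labels))
    lines = ["{}\t{} => {}".format(f, LABEL_NAMES[r], LABEL_NAMES[p])
             for f, r, p in zip(fpaths, real_labels, predicted)]
    return lines, correct / len(real_labels)


lines, acc = summarize(["a.jpg", "b.jpg"], [0, 2],
                       [[2.0, 0.1, 0.3], [0.2, 1.5, 0.9]])
for line in lines:
    print(line)
print("accuracy:", acc)
```

In this toy call the second image is misclassified as an oil painting, so the accuracy is 0.5.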

Training results
The results are pretty good. With that, the painting-style recognition mini project is essentially complete (cue the confetti).

More improvements to come.