1. 程式人生 > >tensorflow神經網路結構視覺化

tensorflow神經網路結構視覺化

藉助 tensorboard 實現tensorflow中定義的深度神經網路視覺化。


在程式中實現網路視覺化,只需要在載入網路之後,加上這一句:

summary_writer = tf.summary.FileWriter('./log/', sess.graph)

上邊的 sess.graph 就是定義的網路結構了,使用summary.FileWriter 方法儲存到本地。或者:

summary_writer = tf.summary.FileWriter('./log/', tf.get_default_graph())

完整程式碼:

# -*- coding: utf-8 -*-
import tensorflow as tf


# 影象大小
IMAGE_HEIGHT = 256
IMAGE_WIDTH = 256
MAX_CAPTCHA = 4
CHAR_SET_LEN = 10

input = tf.placeholder(tf.float32, [None, IMAGE_HEIGHT , IMAGE_WIDTH, 1])

# 定義CNN
def crack_captcha_cnn(x=input, w_alpha=0.01, b_alpha=0.1):
    # conv layer
    w_c1 = tf.Variable(w_alpha * tf.random_normal([3, 3, 1, 32]))
    b_c1 = tf.Variable(b_alpha * tf.random_normal([32]))
    conv1 = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(x, w_c1, strides=[1, 1, 1, 1], padding='SAME'), b_c1))
    conv1 = tf.nn.max_pool(conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

    # Fully connected layer
    w_d = tf.Variable(w_alpha * tf.random_normal([8 * 20 * 64, 1024]))
    b_d = tf.Variable(b_alpha * tf.random_normal([1024]))
    dense = tf.reshape(conv1, [-1, w_d.get_shape().as_list()[0]])
    dense = tf.nn.relu(tf.add(tf.matmul(dense, w_d), b_d))

    w_out = tf.Variable(w_alpha * tf.random_normal([1024, MAX_CAPTCHA * CHAR_SET_LEN]))
    b_out = tf.Variable(b_alpha * tf.random_normal([MAX_CAPTCHA * CHAR_SET_LEN]))
    out = tf.add(tf.matmul(dense, w_out), b_out)
    return out

# 載入網路
evaluate_net = crack_captcha_cnn()

with tf.Session() as sess:
    # 網路結構寫入
    summary_writer = tf.summary.FileWriter('./log/', sess.graph)
    # summary_writer = tf.summary.FileWriter('./log/', tf.get_default_graph())

print('OK')

執行完成之後在程式目錄下生成log資料夾,儲存了網路資訊,使用tensorboard執行:

tensorboard --logdir=log

在瀏覽器輸入返回的網址,就可以看到網路結構了: