tensorflow神經網路結構視覺化
阿新 • • 發佈:2018-12-14
藉助 tensorboard 實現tensorflow中定義的深度神經網路視覺化。
在程式中實現網路視覺化,只需要在載入網路之後,加上這一句:
summary_writer = tf.summary.FileWriter('./log/', sess.graph)
上邊的 sess.graph 就是定義的網路結構了,使用summary.FileWriter 方法儲存到本地。或者:
summary_writer = tf.summary.FileWriter('./log/', tf.get_default_graph())
完整程式碼:
# -*- coding: utf-8 -*- import tensorflow as tf # 影象大小 IMAGE_HEIGHT = 256 IMAGE_WIDTH = 256 MAX_CAPTCHA = 4 CHAR_SET_LEN = 10 input = tf.placeholder(tf.float32, [None, IMAGE_HEIGHT , IMAGE_WIDTH, 1]) # 定義CNN def crack_captcha_cnn(x=input, w_alpha=0.01, b_alpha=0.1): # conv layer w_c1 = tf.Variable(w_alpha * tf.random_normal([3, 3, 1, 32])) b_c1 = tf.Variable(b_alpha * tf.random_normal([32])) conv1 = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(x, w_c1, strides=[1, 1, 1, 1], padding='SAME'), b_c1)) conv1 = tf.nn.max_pool(conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') # Fully connected layer w_d = tf.Variable(w_alpha * tf.random_normal([8 * 20 * 64, 1024])) b_d = tf.Variable(b_alpha * tf.random_normal([1024])) dense = tf.reshape(conv1, [-1, w_d.get_shape().as_list()[0]]) dense = tf.nn.relu(tf.add(tf.matmul(dense, w_d), b_d)) w_out = tf.Variable(w_alpha * tf.random_normal([1024, MAX_CAPTCHA * CHAR_SET_LEN])) b_out = tf.Variable(b_alpha * tf.random_normal([MAX_CAPTCHA * CHAR_SET_LEN])) out = tf.add(tf.matmul(dense, w_out), b_out) return out # 載入網路 evaluate_net = crack_captcha_cnn() with tf.Session() as sess: # 網路結構寫入 summary_writer = tf.summary.FileWriter('./log/', sess.graph) # summary_writer = tf.summary.FileWriter('./log/', tf.get_default_graph()) print('OK')
執行完成之後在程式目錄下生成log資料夾,儲存了網路資訊,使用tensorboard執行:
tensorboard --logdir=log
在瀏覽器輸入返回的網址,就可以看到網路結構了: