Deep Dive into TensorFlow Embedding: Fixing TensorBoard's Blank PROJECTOR
Motivation
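TensorBoard's PROJECTOR is a very handy feature: it lets us inspect the data inside a network in two or three dimensions. However, many embedding examples are too simple and contain no real training process. When their recipes are carried over to an actual training run, all sorts of problems appear, and the PROJECTOR panel in TensorBoard stays blank, never showing the embedding that was written. At bottom, this comes from not really understanding how TensorBoard's embedding mechanism works. After a lot of trial and error I finally have a basic grasp of it, and this post walks through TensorFlow embeddings in depth.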
Embedding
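TensorBoard's PROJECTOR involves three steps: 1. getting the data, 2. getting the data labels, 3. getting the data images. The sections below explain each step with a simple MNIST example (the complete code is at the end). First, a screenshot of the final result: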
Getting the data
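First, note that the data TensorBoard shows in the PROJECTOR is read from the tensors saved by saver(). That is why the embedding mechanism stores the images or data to be projected in a tf.Variable (embedding_var below). Why a tf.Variable? Because saver() saves the network's weights, and those weights are essentially all held in tf.Variable objects. A placeholder created with tf.placeholder is not saved: it holds no state of its own and only receives values fed in at run time, so there is nothing for the checkpoint to store. Since embedding_var is a tf.Variable, it must be initialized in a session, so it has to be created before sess.run(tf.global_variables_initializer()) and saver.save() are called. The corresponding code is: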
```python
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import os
from tensorflow.contrib.tensorboard.plugins import projector
from tensorflow.examples.tutorials.mnist import input_data

LOG_DIR = 'minimalsample'
NAME_TO_VISUALISE_VARIABLE = "mnistembedding"
TO_EMBED_COUNT = 500

mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)
batch_xs, batch_ys = mnist.train.next_batch(TO_EMBED_COUNT)

# Store the data to project in a tf.Variable so that saver() writes it to the checkpoint
embedding_var = tf.Variable(batch_xs, name=NAME_TO_VISUALISE_VARIABLE)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

# Saver.save() does not create missing directories, so make sure LOG_DIR exists
if not os.path.exists(LOG_DIR):
    os.makedirs(LOG_DIR)
saver = tf.train.Saver()
# Save into LOG_DIR, the same directory TensorBoard will read (global_step=1)
saver.save(sess, os.path.join(LOG_DIR, "model.ckpt"), 1)
```
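Because the PROJECTOR reads its data from the checkpoint, a blank PROJECTOR often just means that no checkpoint landed in the directory TensorBoard is watching. Below is a minimal, stdlib-only sketch for checking this. It assumes the plain-text "checkpoint" state file that tf.train.Saver writes next to the checkpoint (containing a line like model_checkpoint_path: "model.ckpt-1") and the V2 layout with .index and .meta files. The helpers latest_checkpoint_prefix and checkpoint_files_exist are my own hypothetical functions, not part of TensorFlow.

```python
import os
import re


# Hypothetical helper (not part of TensorFlow)
def latest_checkpoint_prefix(log_dir):
    """Return the checkpoint prefix recorded in log_dir/checkpoint, or None.

    Assumes the plain-text state file written by tf.train.Saver, e.g. a line
    model_checkpoint_path: "model.ckpt-1"
    """
    state_file = os.path.join(log_dir, "checkpoint")
    if not os.path.isfile(state_file):
        return None  # Saver.save() never wrote into this directory
    with open(state_file) as f:
        for line in f:
            match = re.match(r'\s*model_checkpoint_path:\s*"(.*)"', line)
            if match:
                prefix = match.group(1)
                # The recorded path may be relative to log_dir or absolute
                if os.path.isabs(prefix):
                    return prefix
                return os.path.join(log_dir, prefix)
    return None


# Hypothetical helper (not part of TensorFlow)
def checkpoint_files_exist(prefix):
    """Check for the .index and .meta files that accompany a V2 checkpoint prefix."""
    return all(os.path.isfile(prefix + ext) for ext in (".index", ".meta"))
```

After the snippet above has run, latest_checkpoint_prefix('minimalsample') should return a path ending in model.ckpt-1, since global_step=1 is appended to the file name.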
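That is the core of the embedding mechanism: the data is stored in a Variable, and TensorBoard can then read it from the ckpt file written by saver(). We also need to register a projector so that TensorBoard shows the PROJECTOR tab. Registering it requires information about the embedding, such as the name of the embedded variable, the path of the tsv file (explained below), and the path of the sprite image (explained below). Writing this configuration through a tf.summary.FileWriter is what registers the projector. The code is: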
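```python
# The FileWriter must point at the same LOG_DIR that the checkpoint is saved to
summary_writer = tf.summary.FileWriter(LOG_DIR)

config = projector.ProjectorConfig()
embedding = config.embeddings.add()
embedding.tensor_name = embedding_var.name
# Specify where you find the metadata
# (a relative path is resolved against the directory holding projector_config.pbtxt, i.e. LOG_DIR)
embedding.metadata_path = 'metadata.tsv'
# Specify where you find the sprite (we will create this later)
embedding.sprite.image_path = 'mnistdigits.png'
embedding.sprite.single_image_dim.extend([28, 28])
# Say that you want to visualise the embeddings: this writes LOG_DIR/projector_config.pbtxt
projector.visualize_embeddings(summary_writer, config)
```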
With this in place, TensorBoard will show the embedding in the PROJECTOR.
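projector.visualize_embeddings() stores the registration as a text-format protobuf in LOG_DIR/projector_config.pbtxt. Opening that file is a quick way to check that tensor_name and the paths are what you expect. Note that for a variable created with name="mnistembedding", embedding_var.name is normally "mnistembedding:0". The following sketch builds an equivalent file by hand, assuming the ProjectorConfig field names used in the code above (embeddings, tensor_name, metadata_path, sprite.image_path, sprite.single_image_dim). render_projector_config and write_projector_config are hypothetical helpers, and the file TensorFlow itself writes may differ in whitespace or field order.

```python
import os


# Hypothetical helper: hand-built equivalent of a ProjectorConfig with one embedding
def render_projector_config(tensor_name, metadata_path, sprite_path, image_dim):
    """Return text-format protobuf describing a single embedding entry."""
    lines = [
        "embeddings {",
        '  tensor_name: "%s"' % tensor_name,
        '  metadata_path: "%s"' % metadata_path,
        "  sprite {",
        '    image_path: "%s"' % sprite_path,
    ]
    # single_image_dim is a repeated field holding [width, height]
    lines += ["    single_image_dim: %d" % d for d in image_dim]
    lines += ["  }", "}"]
    return "\n".join(lines) + "\n"


# Hypothetical helper: put the text where TensorBoard looks for it
def write_projector_config(log_dir, text):
    """Write text to log_dir/projector_config.pbtxt and return that path."""
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    path = os.path.join(log_dir, "projector_config.pbtxt")
    with open(path, "w") as f:
        f.write(text)
    return path
```

For the MNIST example, render_projector_config("mnistembedding:0", "metadata.tsv", "mnistdigits.png", [28, 28]) produces one embeddings block with two single_image_dim: 28 lines.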
Getting the data labels
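To show labels and images in TensorBoard, two files are involved: 1. a tsv file and 2. a sprite image. In the example below they are metadata.tsv and mnistdigits.png.
- metadata.tsv: holds the label of each row in embedding_var, so that TensorBoard can color the data points by label.
- mnistdigits.png: holds an image representation of each data point.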
metadata.tsv is generated as follows:
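```python
path_for_mnist_metadata = os.path.join(LOG_DIR, 'metadata.tsv')

with open(path_for_mnist_metadata, 'w') as f:
    # Two columns, so the first row is a header
    f.write("Index\tLabel\n")
    # Row i of the file describes row i of embedding_var
    for index, label in enumerate(batch_ys):
        f.write("%d\t%d\n" % (index, label))
```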
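The file above has two columns and therefore starts with a header row. As far as I know, TensorBoard treats the first row as column headers only when there is more than one column, and a single-column metadata file should contain just the values, with no header. A header on a one-column file shifts every label by one row. Here is a small sketch of a writer that follows that convention. write_metadata is a hypothetical helper, and the header rule is stated as my understanding of TensorBoard's format.

```python
# Hypothetical helper: write projector metadata following the header convention
def write_metadata(path, columns):
    """Write a metadata TSV from {column name: list of values}.

    A header row is written only when there is more than one column;
    a single-column file contains just the values, one per line.
    """
    names = list(columns)
    if not names:
        raise ValueError("need at least one column")
    n_rows = len(columns[names[0]])
    if any(len(columns[name]) != n_rows for name in names):
        raise ValueError("all metadata columns must have the same length")
    with open(path, "w") as f:
        if len(names) > 1:
            f.write("\t".join(names) + "\n")
        for i in range(n_rows):
            f.write("\t".join(str(columns[name][i]) for name in names) + "\n")
```

With the MNIST batch, write_metadata(path_for_mnist_metadata, {"Index": list(range(len(batch_ys))), "Label": list(batch_ys)}) produces the same layout as the loop above.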
As you can see, the metadata holds each data point's index and label, and these indices and labels correspond row by row to the data in embedding_var.
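Since labels are matched to points purely by row position, a metadata file whose row count differs from the number of rows in embedding_var (500 here) leaves the labels misaligned. Below is a minimal sketch of such a check. It assumes the header convention above, a header only when the first line contains a tab, and the helpers metadata_row_count and check_alignment are hypothetical.

```python
# Hypothetical helper: count data rows in a metadata TSV
def metadata_row_count(path):
    """Count data rows, skipping the header row that multi-column files carry."""
    with open(path) as f:
        lines = f.read().splitlines()
    if not lines:
        return 0
    has_header = "\t" in lines[0]  # multi-column file, so line 0 is the header
    return len(lines) - 1 if has_header else len(lines)


# Hypothetical helper: compare against the number of embedding rows
def check_alignment(path, n_embedding_rows):
    """Raise ValueError unless there is exactly one metadata row per embedding row."""
    n = metadata_row_count(path)
    if n != n_embedding_rows:
        raise ValueError("metadata has %d rows but the embedding has %d"
                         % (n, n_embedding_rows))
```

In the example, check_alignment(path_for_mnist_metadata, TO_EMBED_COUNT) should pass silently.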
At this point TensorBoard already displays the data points, similar to this:
Different colors represent different classes.
Getting the data images
To represent the data points with images, we need mnistdigits.png. Generating it uses three functions:
```python
import os
import numpy as np
import matplotlib.pyplot as plt

def create_sprite_image(images):
    """Returns a sprite image consisting of images passed as argument. Images should be count x height x width"""
    if isinstance(images, list):
        images = np.array(images)
    img_h = images.shape[1]
    img_w = images.shape[2]
    # Side length of the square grid, e.g. 23 for 500 images
    n_plots = int(np.ceil(np.sqrt(images.shape[0])))

    # Start from an all-white (1.0) canvas; unused cells stay white
    spriteimage = np.ones((img_h * n_plots, img_w * n_plots))

    for i in range(n_plots):
        for j in range(n_plots):
            this_filter = i * n_plots + j
            if this_filter < images.shape[0]:
                this_img = images[this_filter]
                spriteimage[i * img_h:(i + 1) * img_h,
                            j * img_w:(j + 1) * img_w] = this_img

    return spriteimage

def vector_to_matrix_mnist(mnist_digits):
    """Reshapes normal mnist digit (batch,28*28) to matrix (batch,28,28)"""
    return np.reshape(mnist_digits, (-1, 28, 28))

def invert_grayscale(mnist_digits):
    """ Makes black white, and white black """
    return 1 - mnist_digits

to_visualise = batch_xs
to_visualise = vector_to_matrix_mnist(to_visualise)
to_visualise = invert_grayscale(to_visualise)
sprite_image = create_sprite_image(to_visualise)

# Save the sprite into LOG_DIR under the name given in embedding.sprite.image_path
path_for_mnist_sprites = os.path.join(LOG_DIR, 'mnistdigits.png')
plt.imsave(path_for_mnist_sprites, sprite_image, cmap='gray')
```
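vector_to_matrix_mnist turns each vector into a matrix, i.e. its image form. invert_grayscale inverts the gray levels: MNIST digits are white strokes on a black background, so after inversion they become black digits on a white background, which is the usual choice. create_sprite_image then tiles the data from embedding_var, after it has passed through vector_to_matrix_mnist and invert_grayscale, into a single image wall, as shown below: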
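The size of that image wall follows directly from create_sprite_image. With N images of size h x w, the grid side is ceil(sqrt(N)) and the sprite is that many tiles wide and high. This matters because the TensorBoard embedding documentation gives a maximum sprite size of 8192 x 8192 pixels. I treat that figure as an assumption that may vary between versions. The sketch below computes the layout and the largest number of tiles that still fits. sprite_layout and max_images_for_sprite are hypothetical helpers.

```python
import math

# Assumed limit, taken from the TensorBoard embedding docs; may vary by version
MAX_SPRITE_SIDE = 8192


# Hypothetical helper: same grid rule as create_sprite_image
def sprite_layout(n_images, img_h, img_w):
    """Return (n_plots, sprite_h, sprite_w) for a square grid of n_images tiles."""
    n_plots = int(math.ceil(math.sqrt(n_images)))
    return n_plots, img_h * n_plots, img_w * n_plots


# Hypothetical helper: capacity of a square grid under the size limit
def max_images_for_sprite(img_h, img_w, max_side=MAX_SPRITE_SIDE):
    """Largest image count whose square grid fits inside max_side x max_side."""
    per_side = min(max_side // img_h, max_side // img_w)
    return per_side * per_side
```

For the 500 MNIST digits, sprite_layout(500, 28, 28) gives a 23 x 23 grid, i.e. a 644 x 644 pixel sprite, far below the limit. At 28 x 28 pixels per digit, at most 292 x 292 = 85264 tiles fit.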
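It contains 500 handwritten digits, and this 500 matches the number of rows in embedding_var. TensorBoard can now use each point's index together with the per-image size [28x28] to cut the corresponding image out of the sprite and attach it to the matching point in the PROJECTOR. Done! The result looks like this: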
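The index-to-image lookup is simply the inverse of the layout used by create_sprite_image. Image i sits in grid row i // n_plots and grid column i % n_plots. The sketch below is my own illustration of that mapping, assuming the same square, row-major grid. It is not TensorBoard's actual code, and sprite_tile is a hypothetical helper.

```python
import math


# Hypothetical helper: locate image `index` inside a row-major square sprite
def sprite_tile(index, n_images, img_h, img_w):
    """Return the pixel box (top, left, bottom, right) of image `index`."""
    if not 0 <= index < n_images:
        raise IndexError("index out of range")
    n_plots = int(math.ceil(math.sqrt(n_images)))
    row, col = divmod(index, n_plots)  # row-major: fill a row before moving down
    top, left = row * img_h, col * img_w
    return top, left, top + img_h, left + img_w
```

For the 500-image sprite, image 23 is the first tile of the second row, (28, 0, 56, 28), and the last image, 499, sits at (588, 448, 616, 476).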
Reference
https://www.pinchofintelligence.com/simple-introduction-to-tensorboard-embedding-visualisation/
Full code
```python
import os
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
from tensorflow.contrib.tensorboard.plugins import projector
from tensorflow.examples.tutorials.mnist import input_data

LOG_DIR = 'minimalsample'
NAME_TO_VISUALISE_VARIABLE = "mnistembedding"
TO_EMBED_COUNT = 500

path_for_mnist_sprites = os.path.join(LOG_DIR, 'mnistdigits.png')
path_for_mnist_metadata = os.path.join(LOG_DIR, 'metadata.tsv')

mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)
batch_xs, batch_ys = mnist.train.next_batch(TO_EMBED_COUNT)

# 1. The data to project, stored in a Variable so that saver() checkpoints it
embedding_var = tf.Variable(batch_xs, name=NAME_TO_VISUALISE_VARIABLE)
# The FileWriter also creates LOG_DIR if it does not exist yet
summary_writer = tf.summary.FileWriter(LOG_DIR)

def create_sprite_image(images):
    """Returns a sprite image consisting of images passed as argument. Images should be count x height x width"""
    if isinstance(images, list):
        images = np.array(images)
    img_h = images.shape[1]
    img_w = images.shape[2]
    n_plots = int(np.ceil(np.sqrt(images.shape[0])))

    spriteimage = np.ones((img_h * n_plots, img_w * n_plots))

    for i in range(n_plots):
        for j in range(n_plots):
            this_filter = i * n_plots + j
            if this_filter < images.shape[0]:
                this_img = images[this_filter]
                spriteimage[i * img_h:(i + 1) * img_h,
                            j * img_w:(j + 1) * img_w] = this_img

    return spriteimage

def vector_to_matrix_mnist(mnist_digits):
    """Reshapes normal mnist digit (batch,28*28) to matrix (batch,28,28)"""
    return np.reshape(mnist_digits, (-1, 28, 28))

def invert_grayscale(mnist_digits):
    """ Makes black white, and white black """
    return 1 - mnist_digits

# 2. Register the projector: writes LOG_DIR/projector_config.pbtxt
config = projector.ProjectorConfig()
embedding = config.embeddings.add()
embedding.tensor_name = embedding_var.name
# Specify where you find the metadata
embedding.metadata_path = 'metadata.tsv'
# Specify where you find the sprite (we will create this later)
embedding.sprite.image_path = 'mnistdigits.png'
embedding.sprite.single_image_dim.extend([28, 28])
# Say that you want to visualise the embeddings
projector.visualize_embeddings(summary_writer, config)

# 3. Initialize and checkpoint the embedding variable into LOG_DIR
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.save(sess, os.path.join(LOG_DIR, "model.ckpt"), 1)

# 4. Build and save the sprite image
to_visualise = batch_xs
to_visualise = vector_to_matrix_mnist(to_visualise)
to_visualise = invert_grayscale(to_visualise)
sprite_image = create_sprite_image(to_visualise)

plt.imsave(path_for_mnist_sprites, sprite_image, cmap='gray')
plt.imshow(sprite_image, cmap='gray')

# 5. Write the labels, one row per row of embedding_var
with open(path_for_mnist_metadata, 'w') as f:
    f.write("Index\tLabel\n")
    for index, label in enumerate(batch_ys):
        f.write("%d\t%d\n" % (index, label))
```
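If the PROJECTOR is still blank after running the full script, the first thing I would check is whether every file it relies on actually ended up in LOG_DIR. Below is a minimal sketch of that check. The expected file names are the ones used by the script above (projector_config.pbtxt from visualize_embeddings, checkpoint from Saver, plus the metadata and sprite files), and missing_projector_files is a hypothetical helper.

```python
import os

# File names produced by the full script above; adjust them if you change the paths
EXPECTED_FILES = ("projector_config.pbtxt", "checkpoint", "metadata.tsv", "mnistdigits.png")


# Hypothetical helper: list which expected files are absent from log_dir
def missing_projector_files(log_dir, expected=EXPECTED_FILES):
    """Return the expected file names not found in log_dir (empty list = all present)."""
    return [name for name in expected
            if not os.path.isfile(os.path.join(log_dir, name))]
```

If missing_projector_files('minimalsample') returns an empty list, start TensorBoard on that same directory with tensorboard --logdir minimalsample.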