1. 程式人生 > >數字標籤轉化為one-hot形式的tensor

數字標籤轉化為one-hot形式的tensor

剛剛入tensorflow的大坑,根據自己的理解今天來記錄一下如何將數字標籤轉化為one-hot形式。有錯誤的請諒解哈哈哈
what is form of one-hot ?即用包含0和1的tensor來表示數字標籤,數字1所在的索引值(從0開始)即為我們的數字標籤,例如我們有0-9的數字標籤,則標籤5所對應的one-hot形式為[0 , 0 , 0 , 0 , 0 ,1 , 0 , 0 , 0 , 0],因為1所在位置的索引值為5。
And how to transform the digital labels into one-hot tensor ? 我先把tensorflow 官網(tensorflow.google.cn)上的程式碼貼出來:

這裡寫圖片描述

現在看看每一個語句的作用:

1.batch_size = tf.size(labels)

我們需要對所有的數字標籤進行轉換,所以第一步,先通過上述語句獲取標籤的個數。比如,我們的標籤陣列為 labels = [1 , 4 , 6 , 8 , 3 , 7],則有
這裡寫圖片描述

2.labels_1 = tf.expand_dims(labels, 1)

該語句的是將labels的維度索引軸axis為1處(從0開始)插入1的尺寸。現在我們先看看labels的shape:
這裡寫圖片描述
那麼,經過語句2之後,有這裡寫圖片描述

如上所述
從一開始的shape[6](axis=0處有6個值)變成了shape[6,1](axis=0時有6個值,axis=1處插入1),不清楚的話可以試試:
labels_1 = tf.expand_dims(labels, 0)

補充: tf.expand_dims(input, axis=None)函式表示給定輸入tensor,在輸入shape的維度索引軸axis處插入為1的尺寸。 維度索引軸從0開始; 如果axis為負數,則從後向前計數。

3.indices =    tf.expand_dims(tf.range(0,batch_size,1),1)

tf.range(start, limit, delta=1)函式是用來生成tensor等差序列,序列在start到limit之間(包含start不包含limit),步長為dalta。
語句3先生成0-5的向量,再同語句2同樣的擴充套件維度:
這裡寫圖片描述

4.concated 
= tf.concat([indices, labels_1],1)

concated = tf.concat([indices, labels_1],concat_dim)表示在第concat_dim+1個維度疊加,例如,語句4的輸出為:

這裡寫圖片描述
其中,concat_dim=0(第一個維度)可以認為是行,concat_dim=1(第二個維度)為列,所以語句4 在列上疊加。更多維度的資訊自行查閱。

5.onehot_labels = tf.sparse_to_dense(concated, tf.stack([batch_size, 10]), 1.0, 0.0)

在語句5中,concated矩陣(如上圖)表示[0,1],[1,4]….[5,9]有值,tf.stack([batch_size, 10])表示輸出的one-hot矩陣,每一行表示一個標籤對應的ont-hot形式。1.0為one-hot的意義所在,即ont-hot矩陣中對應於concated矩陣有值的位置為1.0,然後0.0表示沒值的位置為0.0。
語句5 的輸出結果如圖:
這裡寫圖片描述