1. 程式人生 > >tf.one_hot()函式簡介

tf.one_hot()函式簡介

tf.one_hot()函式是將input轉化為one-hot型別資料輸出,相當於將多個數值聯合放在一起作為多個相同型別的向量,可用於表示各自的概率分佈,通常用於分類任務中作為最後的FC層的輸出,有時翻譯成“獨熱”編碼。

tensorflow的help中相關說明如下:

one_hot(indices, depth, on_value=None, off_value=None, axis=None, dtype=None, name=None)
    Returns a one-hot tensor.

indices表示輸入的多個數值,通常是矩陣形式;depth表示輸出的尺寸。
由於one-hot型別資料長度為depth位,其中只用一位數字表示原輸入資料,這裡的on_value就是這個數字,預設值為1,one-hot資料的其他位用off_value表示,預設值為0。

tf.one_hot()函式規定輸入的元素indices從0開始,最大的元素值不能超過(depth - 1),因此能夠表示(depth + 1)個單位的輸入。若輸入的元素值超出範圍,輸出的編碼均為 [0, 0 … 0, 0]。

indices = 0 對應的輸出是[1, 0 … 0, 0], indices = 1 對應的輸出是[0, 1 … 0, 0], 依次類推,最大可能值的輸出是[0, 0 … 0, 1]。

程式碼示例如下:

import tensorflow as tf  

classes = 3
labels = tf.constant([0,1,2]) # 輸入的元素值最小為0,最大為2
output = tf.one_hot(labels,classes)

sess = tf.Session()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    output = sess.run(output)
    print
("output of one-hot is : ",output) # ('output of one-hot is : ', array([[ 1., 0., 0.], # [ 0., 1., 0.], # [ 0., 0., 1.]], dtype=float32))