1. 程式人生 > >關於tf.one_hot()需知其所以然

關於tf.one_hot()需知其所以然

該命令一般用於cost function

 

import tensorflow as tf
N_CLASSES = 5
labels = [1,4,0,2,3,0,1,1]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    A = sess.run(tf.one_hot(labels,N_CLASSES))
    print('one_hot:')
    print(A)
    
cost = tf.losses.softmax_cross_entropy(A,logits,weight = 1)

以上程式碼最後一行不執行,得到結果如下

結合最後一行程式碼就好理解了。

最後一層輸出的維數=N,此例=5。

y_{i}=\frac{exp(logits_{i})}{\sum _{n=1}^{N}exp(logits_{n})}

H = -\sum_{i=1}^{M} label_{i}logy_{i}

H為交叉熵,M為batch無誤,此例M=8。