Tensorflow——————API 使用方法記錄
阿新 • • 發佈:2018-12-10
返回一個onehot tensor
tf.one_hot(
indices,
depth,
on_value=None,
off_value=None,
axis=None,
dtype=None,
name=None
)
-
indices
: tensor型別值得索引. -
depth
: 代表one-hot得深度,可以理解為對應得類別數. -
on_value
: 一個標量值,索引位置上得值. (default: 1). -
off_value
: 索引之外的值. (default: 0) -
axis
: -
dtype
: The data type of the output tensor. -
name
: op節點名稱 (optional).
一個N維的輸入,則輸出維度為N+1,意思就是:如果輸入是一個標量,則輸出為一個固定深度的向量;如果輸入是一個 固定深度的向量,則輸出是一個二維tensor。上述兩個輸入,又分兩個情況(axis=-1或者0)。
# 輸入為標量 [features, depth] if axis == -1 [depth, features] if axis == 0 # 輸入為固定長度的向量 [batch, features x depth] if axis == -1 [batch, depth x features] if axis == 1 [depth, batch x features] if axis == 0
舉例如下:
# 標量 indices = 3 depth = 4 a = tf.one_hot(indices, depth, axis=-1) sess = tf.InteractiveSession() print(sess.run(a)) # 輸出 # axis=-1或者0,結果都是一樣的,因為就一個維度,怎麼取都是那一個。。。。。 [0. 0. 0. 1.] -------------------------------------------------------------------------------- # 輸入為固定長度向量 indices = [0, 1, 3, 2] depth = 5 a = tf.one_hot(indices, depth, axis=-1) sess = tf.InteractiveSession() print(sess.run(a)) # axis=-1時 [[1. 0. 0. 0. 0.] [0. 1. 0. 0. 0.] [0. 0. 0. 1. 0.] [0. 0. 1. 0. 0.]] # axis=0時 [[1. 0. 0. 0.] [0. 1. 0. 0.] [0. 0. 0. 1.] [0. 0. 1. 0.] [0. 0. 0. 0.]]
當輸出為[batch, features]時(訓練的時候都是批次的),輸出維度應該是這樣的:
[batch, features, depth] if axis == -1
[batch, depth, features] if axis == 1
[depth, batch, features] if axis == 0
舉例:
# 對這種輸入的one-hot,我是一般不會遇到,平時模型訓練,label都是一維的
indices = [[0, 2], [1, -1], [2, 3]]
depth = 5
a = tf.one_hot(indices, depth, axis=-1)
sess = tf.InteractiveSession()
print(sess.run(a))
# 輸出
# aixs=-1
[[[1. 0. 0. 0. 0.]
[0. 0. 1. 0. 0.]]
[[0. 1. 0. 0. 0.]
[0. 0. 0. 0. 0.]]
[[0. 0. 1. 0. 0.]
[0. 0. 0. 1. 0.]]]
# axis=0
[[[1. 0.]
[0. 0.]
[0. 0.]]
[[0. 0.]
[1. 0.]
[0. 0.]]
[[0. 1.]
[0. 0.]
[1. 0.]]
[[0. 0.]
[0. 0.]
[0. 1.]]
[[0. 0.]
[0. 0.]
[0. 0.]]]
持續更新·············