Tensorflow學習---tf.nn.embedding_lookup
tf.nn.embedding_lookup(params,ids, partition_strategy=’mod’, name=None, validate_indices=True,max_norm=None)
根據ids中的id,尋找params中的對應元素,可以理解為索引,所以ids中元素值不能超出params的第一維的維數值。
比如,ids=[1,3,5],則找出params中下標為1,3,5的向量組成一個矩陣返回。
引數說明:
params: 表示完整的embedding張量,或者除了第一維度之外具有相同形狀的P個張量的列表,表示經分割的嵌入張量。
ids: 一個型別為int32或int64的Tensor,包含要在params中查詢的id
下面是程式碼
#!/usr/bin/python
#encoding:utf-8
import tensorflow
as tf
encode_embeddings = tf.constant([[1,2,3,4,5],[6,7,8,9,0]]) #2*5
# input_ids中元素的值和encode_embeddings的第一維的維數有關,此例中為2維,input_ids只能是[0,2),也就是0和1
input_ids =tf.constant([[1,1,0],[1,0,1],[1,0,
1],[0,1,
1]]) #4*3
session = tf.Session()
with session.as_default():
#
results =tf.nn.embedding_lookup(encode_embeddings,input_ids)
print(results.eval())# tf.eval()函式用於顯示張量tensor的值,但需要放在with session.as_default()中才能使用。
'''結果值
[[[6 7 8 9 0]
[6 7 8 9 0]
[1 2 3 4 5]]
[[6 7 8 9 0]
[1 2 3 4 5]
[6 7 8 9 0]]
[[6 7 8 9 0]
[1 2 3 4 5]
[6 7 8 9 0]]
[[1 2 3 4 5]
[6 7 8 9 0]
[6 7 8 9 0]]]'''