tf.nn.embedding_lookup函式用法
阿新 • • 發佈:2020-10-11
tf.nn.embedding_lookup函式主要用於選取一個張量或者數組裡對應元素的值,即輸入一個索引,輸出該索引對應的值。
先看看引數
def embedding_lookup(
params,
ids,
partition_strategy="mod",
name=None,
validate_indices=True, # pylint: disable=unused-argument
max_norm=None):
其中params跟ids比較重要,params即為資料來源,ids即索引。
詳見例項:
import tensorflow as tf # 生成5*1的張量 var = tf.Variable(tf.random.normal([5, 1])) # 查詢張量中的索引為0和4的 ans = tf.nn.embedding_lookup(var, [0,4]) # 分別查詢張量中的索引為0、4以及1、2的 ans1 = tf.nn.embedding_lookup(var, [[0,4],[1,2]]) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) print('==================var=============================') print(sess.run(var)) print('==================ans=============================') print(sess.run(ans)) print('==================ans1=============================') print(sess.run(ans1)) ##########################輸出結果########################### ==================var============================= [[-0.10888958] [-0.94979066] [-0.7073568 ] [-0.86004704] [-0.1758791 ]] ==================ans============================= [[-0.10888958] [-0.1758791 ]] ==================ans1============================= [[[-0.10888958] [-0.1758791 ]] [[-0.94979066] [-0.7073568 ]]]