1. 程式人生 > 實用技巧 >tf.nn.embedding_lookup函式用法

tf.nn.embedding_lookup函式用法

tf.nn.embedding_lookup函式主要用於選取一個張量或者數組裡對應元素的值,即輸入一個索引,輸出該索引對應的值。
先看看引數

def embedding_lookup(
    params,
    ids,
    partition_strategy="mod",
    name=None,
    validate_indices=True,  # pylint: disable=unused-argument
    max_norm=None):

其中params跟ids比較重要,params即為資料來源,ids即索引。
詳見例項:

import tensorflow as tf

# 生成5*1的張量
var = tf.Variable(tf.random.normal([5, 1]))
# 查詢張量中的索引為0和4的
ans = tf.nn.embedding_lookup(var, [0,4])
# 分別查詢張量中的索引為0、4以及1、2的
ans1 = tf.nn.embedding_lookup(var, [[0,4],[1,2]])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print('==================var=============================')
    print(sess.run(var))
    print('==================ans=============================')
    print(sess.run(ans))
    print('==================ans1=============================')
    print(sess.run(ans1))

##########################輸出結果###########################
==================var=============================
[[-0.10888958]
 [-0.94979066]
 [-0.7073568 ]
 [-0.86004704]
 [-0.1758791 ]]
==================ans=============================
[[-0.10888958]
 [-0.1758791 ]]
==================ans1=============================
[[[-0.10888958]
  [-0.1758791 ]]

 [[-0.94979066]
  [-0.7073568 ]]]