1. 程式人生 > >Tensorflow學習---tf.nn.embedding_lookup

Tensorflow學習---tf.nn.embedding_lookup

tf.nn.embedding_lookup(params,ids, partition_strategy=’mod’, name=None, validate_indices=True,max_norm=None)
根據ids中的id,尋找params中的對應元素,可以理解為索引,所以ids中元素值不能超出params的第一維的維數值。
比如,ids=[1,3,5],則找出params中下標為1,3,5的向量組成一個矩陣返回。
引數說明:
params: 表示完整的embedding張量,或者除了第一維度之外具有相同形狀的P個張量的列表,表示經分割的嵌入張量。
ids: 一個型別為int32或int64的Tensor,包含要在params中查詢的id

下面是程式碼

#!/usr/bin/python
#encoding:utf-8

import tensorflow as tf
encode_embeddings = tf.constant([[1,2,3,4,5],[6,7,8,9,0]]) #2*5
# input_ids
中元素的值和encode_embeddings的第一維的維數有關,此例中為2維,input_ids只能是[0,2),也就是0和1
input_ids =tf.constant([[1,1,0],[1,0,1],[1,0, 1],[0,1, 1]])  #4*3
session = tf.Session()
with session.as_default():
    #

結果results是4*3*5矩陣。
   
results =tf.nn.embedding_lookup(encode_embeddings,input_ids)
    print(results.eval())# tf.eval()函式用於顯示張量tensor的值,但需要放在with session.as_default()中才能使用。
   
'''結果值
    [[[6 7 8 9 0]
  [6 7 8 9 0]
  [1 2 3 4 5]]

 [[6 7 8 9 0]
  [1 2 3 4 5]
  [6 7 8 9 0]]

 [[6 7 8 9 0]
  [1 2 3 4 5]
  [6 7 8 9 0]]

 [[1 2 3 4 5]
  [6 7 8 9 0]
  [6 7 8 9 0]]]'''