1. 程式人生 > >tf.nn.embedding_lookup

tf.nn.embedding_lookup

tps bottom ida 分享 none class ebp view ews

tf.nn.embedding_lookup記錄

技術分享圖片

tf.nn.embedding_lookup函數的用法主要是選取一個張量裏面索引對應的元素。tf.nn.embedding_lookup(tensor, id):tensor就是輸入張量,id就是張量對應的索引,其他的參數不介紹。

例如:

import tensorflow as tf;
import numpy as np;

c = np.random.random([10,1])
b = tf.nn.embedding_lookup(c, [1, 3])

with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
print sess.run(b)
print c
輸出:
[[ 0.77505197]
[ 0.20635818]]
[[ 0.23976515]
[ 0.77505197]
[ 0.08798201]
[ 0.20635818]
[ 0.37183035]
[ 0.24753178]
[ 0.17718483]
[ 0.38533808]
[ 0.93345168]
[ 0.02634772]]

分析:輸出為張量的第一和第三個元素。
---------------------
作者:UESTC_C2_403
來源:CSDN
原文:https://blog.csdn.net/uestc_c2_403/article/details/72779417
版權聲明:本文為博主原創文章,轉載請附上博文鏈接!

slade_sal 2018.06.11 10:33* 字數 375 閱讀 5099評論 0 技術分享圖片

我覺得這張圖就夠了,實際上tf.nn.embedding_lookup的作用就是找到要尋找的embedding data中的對應的行下的vector。

tf.nn.embedding_lookup(params, ids, partition_strategy=‘mod‘, name=None, validate_indices=True, max_norm=None)

官方文檔位置,其中,params是我們給出的,可以通過:
1.tf.get_variable("item_emb_w", [self.item_count, self.embedding_size])等方式生產服從[0,1]的均勻分布或者標準分布
2.tf.convert_to_tensor轉化我們現有的array
然後,ids是我們要找的params中對應位置。

舉個例子:

import numpy as np
import tensorflow as tf
data = np.array([[[2],[1]],[[3],[4]],[[6],[7]]])
data = tf.convert_to_tensor(data)
lk = [[0,1],[1,0],[0,0]]
lookup_data = tf.nn.embedding_lookup(data,lk)
init = tf.global_variables_initializer()

先讓我們看下不同數據對應的維度:

In [76]: data.shape
Out[76]: (3, 2, 1)
In [77]: np.array(lk).shape
Out[77]: (3, 2)
In [78]: lookup_data
Out[78]: <tf.Tensor ‘embedding_lookup_8:0‘ shape=(3, 2, 2, 1) dtype=int64>

這個是怎麽做到的呢?關鍵的部分來了,看下圖:

技術分享圖片
lk中的值,在要尋找的embedding數據中下找對應的index下的vector進行拼接。永遠是look(lk)部分的維度+embedding(data)部分的除了第一維後的維度拼接。很明顯,我們也可以得到,lk裏面值是必須要小於等於embedding(data)的最大維度減一的。

以上的結果就是:

In [79]: data
Out[79]:
array([[[2],
        [1]],

       [[3],
        [4]],

       [[6],
        [7]]])

In [80]: lk
Out[80]: [[0, 1], [1, 0], [0, 0]]

# lk[0]也就是[0,1]對應著下面sess.run(lookup_data)的結果恰好是把data中的[[2],[1]],[[3],[4]]

In [81]: sess.run(lookup_data)
Out[81]:
array([[[[2],
         [1]],

        [[3],
         [4]]],


       [[[3],
         [4]],

        [[2],
         [1]]],


       [[[2],
         [1]],

        [[2],
         [1]]]])

最後,partition_strategy是用於當len(params) > 1,params的元素分割不能整分的話,則前(max_id + 1) % len(params)多分一個id.
當partition_strategy = ‘mod‘的時候,13個ids劃分為5個分區:[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]],也就是是按照數據列進行映射,然後再進行look_up操作。
當partition_strategy = ‘div‘的時候,13個ids劃分為5個分區:[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]],也就是是按照數據先後進行排序標序,然後再進行look_up操作。

tf.nn.embedding_lookup