1. 程式人生 > >torch.nn.Embedding理解

torch.nn.Embedding理解

Pytorch官網的解釋是:一個儲存了固定字典和大小的簡單查詢表。這個模組常用來儲存詞嵌入和用下標檢索它們。模組的輸入是一個下標的列表,輸出是對應的詞嵌入。

torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2, scale_grad_by_freq=False, sparse=False)

個人理解:這是一個矩陣類,裡面初始化了一個隨機矩陣,矩陣的長是字典的大小,寬是用來表示字典中每個元素的屬性向量,向量的維度根據你想要表示的元素的複雜度而定。類例項化之後可以根據字典中元素的下標來查詢元素對應的向量。

輸入下標0,輸出就是embeds矩陣中第0行。

放程式碼:


除錯過程的引數:



用途:用作自然語言處理中作用很大

而對於一個詞,我們自己去想它的屬性不是很困難嗎,所以這個時候就可以交給神經網路了,我們只需要定義我們想要的維度,比如100,然後通過神經網路去學習它的每一個屬性的大小,而我們並不用關心到底這個屬性代表著什麼,我們只需要知道詞向量的夾角越小,表示他們之間的語義更加接近

http://pytorch-cn.readthedocs.io/zh/latest/package_references/torch-nn/