tensorflow使用預訓練詞向量
阿新 • • 發佈:2019-02-06
目前使用深度網路進行文字任務模型訓練時,第一步應該是將文字轉為詞向量進行處理。但一般詞向量的效果跟語料的大小有關,而處理任務的語料不足支援我們的實驗,這時就需要使用網上公開的大規模語料訓練詞向量。
1、下載
網上公開的詞向量下載地址:https://github.com/xgli/word2vec-api
glove的檔案說明如何使用預訓練詞向量,檔案格式如下:每行為一個單詞和其對應的詞向量,以空格分隔。
glove對應的詞向量,非二進位制檔案
word2vec對應的詞向量,非二進位制檔案
2、裝載
glove詞向量的裝載
filename = 'glove.6B.50d.txt'
def loadGloVe (filename):
vocab = []
embd = []
vocab.append('unk') #裝載不認識的詞
embd.append([0]*emb_size) #這個emb_size可能需要指定
file = open(filename,'r')
for line in file.readlines():
row = line.strip().split(' ')
vocab.append(row[0])
embd.append(row[1:])
print('Loaded GloVe!' )
file.close()
return vocab,embd
vocab,embd = loadGloVe(filename)
vocab_size = len(vocab)
embedding_dim = len(embd[0])
embedding = np.asarray(embd)
- word2vec詞向量的裝載
def loadWord2Vec(filename):
vocab = []
embd = []
cnt = 0
fr = open(filename,'r')
line = fr.readline().decode('utf-8' ).strip()
#print line
word_dim = int(line.split(' ')[1])
vocab.append("unk")
embd.append([0]*word_dim)
for line in fr :
row = line.strip().split(' ')
vocab.append(row[0])
embd.append(row[1:])
print "loaded word2vec"
fr.close()
return vocab,embd
vocab,embd = loadGloVe(filename)
vocab_size = len(vocab)
embedding_dim = len(embd[0])
embedding = np.asarray(embd)vocab:為詞表
embed:為詞的詞向量
3、詞向量層
構建網路時候的詞向量層
W = tf.Variable(tf.constant(0.0, shape=[vocab_size, embedding_dim]),
trainable=False, name="W")
embedding_placeholder = tf.placeholder(tf.float32, [vocab_size, embedding_dim])
embedding_init = W.assign(embedding_placeholder)
- 在網路結構中宣告詞向量矩陣W
sess.run(embedding_init, feed_dict={embedding_placeholder: embedding})
- 在將embedding傳給網路賦值。
4、詞表
此部分對某些任務不適用,比如對話,序列標註等問題,就是這個內建的函式會自動的過濾掉標點符號,但是標點符號也是一些任務需要的資訊。
tf.nn.embedding_lookup(W, input_x)
該程式碼將輸入對映為詞向量,但input_x為詞的id。因此我們需要將輸入文字對映為詞id序列。
from tensorflow.contrib import learn
#init vocab processor
vocab_processor = learn.preprocessing.VocabularyProcessor(max_document_length)
#fit the vocab from glove
pretrain = vocab_processor.fit(vocab)
#transform inputs
input_x = np.array(list(vocab_processor.transform(your_raw_input)))
- 使用tensorflow自帶的詞處理api進行處理,將詞對映成為詞id,同時會過濾掉標點符號。
目前寫這麼多,當時自己寫的時候,進了很多坑,這次寫的也不詳細,如果有不理解的,歡迎評論交流,或發郵件給我(郵件比較及時)。
原作者裡面是錯的,少考慮了“unk”這種情況。請大家注意。
轉自https://blog.csdn.net/lxg0807/article/details/72518962