1. 程式人生 > 實用技巧 >Alink漫談(十六) :Word2Vec原始碼分析 之 建立霍夫曼樹

Alink漫談(十六) :Word2Vec原始碼分析 之 建立霍夫曼樹

Alink漫談(十六) :Word2Vec原始碼分析 之 建立霍夫曼樹

目錄

0x00 摘要

Alink 是阿里巴巴基於實時計算引擎 Flink 研發的新一代機器學習演算法平臺,是業界首個同時支援批式、流式演算法的機器學習平臺。本文和下文將帶領大家來分析Alink中 Word2Vec 的實現。

因為Alink的公開資料太少,所以以下均為自行揣測,肯定會有疏漏錯誤,希望大家指出,我會隨時更新。

0x01 背景概念

1.1 詞向量基礎

1.1.1 獨熱編碼

one-hot編碼就是保證每個樣本中的單個特徵只有1位處於狀態1,其他的都是0。 具體編碼舉例如下,把語料庫中,杭州、上海、寧波、北京每個都對應一個向量,向量中只有一個值為1,其餘都為0。

杭州 [0,0,0,0,0,0,0,1,0,……,0,0,0,0,0,0,0]
上海 [0,0,0,0,1,0,0,0,0,……,0,0,0,0,0,0,0]
寧波 [0,0,0,1,0,0,0,0,0,……,0,0,0,0,0,0,0]
北京 [0,0,0,0,0,0,0,0,0,……,1,0,0,0,0,0,0]

其缺點是:

  • 向量的維度會隨著句子的詞的數量型別增大而增大;如果將世界所有城市名稱對應的向量合為一個矩陣的話,那這個矩陣過於稀疏,並且會造成維度災難。
  • 城市編碼是隨機的,向量之間相互獨立,無法表示語義層面上詞彙之間的相關資訊。

所以,人們想對獨熱編碼做如下改進:

  • 將vector每一個元素由整形改為浮點型,變為整個實數範圍的表示;
  • 轉化為低維度的連續值,也就是稠密向量。將原來稀疏的巨大維度壓縮嵌入到一個更小維度的空間。並且其中意思相近的詞將被對映到向量空間中相近的位置。

簡單說,要尋找一個空間對映,把高維詞向量嵌入到一個低維空間。然後就可以繼續處理

1.1.2 分散式表示

分散式表示(Distributed Representation)其實Hinton 最早在1986年就提出了,基本思想是將每個詞表達成 n 維稠密、連續的實數向量。而實數向量之間的關係可以代表詞語之間的相似度,比如向量的夾角cosine或者歐氏距離。

有一個專門的術語來描述詞向量的分散式表示技術——詞嵌入【word embedding】。從名稱上也可以看出來,獨熱編碼相當於對詞進行編碼,而分散式表示則是將詞從稀疏的大維度壓縮嵌入到較低維度的向量空間中。

Distributed representation 最大的貢獻就是讓相關或者相似的詞,在距離上更接近了。其核心思想是:上下文相似的詞,其語義也相似。這就是著名的詞空間模型(word space model)

Distributed representation 相較於One-hot方式另一個區別是維數下降極多,對於一個100W的詞表,我們可以用100維的實數向量來表示一個詞,而One-hot得要100W維。

為什麼對映到向量空間當中的詞向量就能表示是確定的哪個詞並且還能知道它們之間的相似度呢?

  • 關於為什麼能表示詞這個問題。分散式實際上就是求一個對映函式,這個對映函式將每個詞原始的one-hot表示壓縮投射到一個較低維度的空間並一一對應。所以分散式可以表示確定的詞。
  • 關於為什麼分散式還能知道詞之間的關係。就必須要了解分散式假設(distributional hypothesis)。 其基於的分散式假設就是出現在相同上下文(context)下的詞意思應該相近。所有學習word embedding的方法都是在用數學的方法建模詞和context之間的關係。

詞向量的分散式表示的核心思想由兩部分組成:

  • 選擇一種方式描述上下文;
  • 選擇一種模型刻畫目標詞與其上下文之間的關係。

事實上,不管是神經網路的隱層,還是多個潛在變數的概率主題模型,都是應用分散式表示。

1.2 CBOW & Skip-Gram

在word2vec出現之前,已經有用神經網路DNN來用訓練詞向量進而處理詞與詞之間的關係了。採用的方法一般是一個三層的神經網路結構(當然也可以多層),分為輸入層,隱藏層和輸出層(softmax層)。

這個模型是如何定義資料的輸入和輸出呢?一般分為CBOW(Continuous Bag-of-Words Model) 和 Skip-gram (Continuous Skip-gram Model)兩種模型。

1.2.1 CBOW

CBOW通過上下文來預測當前值。相當於一句話中扣掉一個詞,讓你猜這個詞是什麼。CBOW就是根據某個詞前面的C個詞或者前後C個連續的詞,來計算某個詞出現的概率。

CBOW的訓練過程如下:

  1. Input layer輸出層:是上下文單詞的one hot。假設單詞向量空間的維度為V,即整個詞庫corpus大小為V,上下文單詞視窗的大小為C。
  2. 假設最終詞向量的維度大小為N,則權值共享矩陣為W。W 的大小為 V * N,並且初始化。
  3. 假設語料中有一句話"我愛你"。如果我們現在關注"愛"這個詞,令C=2,則其上下文為"我",“你”。模型把"我" "你"的onehot形式作為輸入。易知其大小為1V。C個1V大小的向量分別跟同一個V * N大小的權值共享矩陣W相乘,得到的是C個1N大小的隱層hidden layer。
  4. C個1N大小的hidden layer取平均,得到一個1N大小的向量,即Hidden layer。
  5. 輸出權重矩陣 W’ 為N V,並進行相應的初始化工作。
  6. 將得到的Hidden layer向量 1N與 W’ 相乘,並且用softmax處理,得到1V的向量,此向量的每一維代表corpus中的一個單詞。概率中最大的index所代表的單詞為預測出的中間詞。
  7. 與groud truth中的one hot比較,求loss function的的極小值。
  8. 通過DNN的反向傳播演算法,我們可以求出DNN模型的引數,同時得到所有的詞對應的詞向量。這樣當我們有新的需求,要求出某8個詞對應的最可能的輸出中心詞時,我們可以通過一次DNN前向傳播演算法並通過softmax啟用函式找到概率最大的詞對應的神經元即可。

1.2.2 Skip-gram

Skip-gram用當前詞來預測上下文。相當於給你一個詞,讓你猜前面和後面可能出現什麼詞。即根據某個詞,然後分別計算它前後出現某幾個詞的各個概率。從這裡可以看出,對於每一個詞,Skip-gram要訓練C次,這裡C是預設的視窗大小,而CBOW只需要計算一次,因此CBOW計算量是Skip-gram的1/C,但也正因為Skip-gram同時擬合了C個詞,因此在避免過擬合上比CBOW效果更好,因此在訓練大型語料庫的時候,Skip-gram的效果比CBOW更好。

Skip-gram的訓練方法與CBOW如出一轍,唯一區別就是Skip-gram的輸入是單個詞的向量,而不是C個詞的求和平均。同時,訓練的話對於一箇中心詞,要訓練C次,每一次是一個不同的上下文詞,比如中心詞是北京,視窗詞是來到天安門這兩個,那麼Skip-gram要對北京-來到北京-天安門進行分別訓練。

目前的實現有一個問題:從隱藏層到輸出的softmax層的計算量很大,因為要計算所有詞的softmax概率,再去找概率最大的值。比如Vocab大小有10^5,那麼每算一個概率都要計算10^5次矩陣乘法,不現實。於是就引入了Word2vec。

1.3 Word2vec

1.3.1 Word2vec基本思想

所謂的語言模型,就是指對自然語言進行假設和建模,使得能夠用計算機能夠理解的方式來表達自然語言。word2vec採用的是n元語法模型(n-gram model),即假設一個詞只與周圍n個詞有關,而與文字中的其他詞無關。

如果 把詞當做特徵,那麼就可以把特徵對映到 K 維向量空間,可以為文字資料尋求更加深層次的特徵表示 。所以 Word2vec的基本思想是 通過訓練將每個詞對映成 K 維實數向量(K 一般為模型中的超引數),通過詞之間的距離(比如 cosine 相似度、歐氏距離等)來判斷它們之間的語義相似度。

其採用一個 三層的神經網路 ,輸入層-隱層-輸出層。有個核心的技術是 根據詞頻用Huffman編碼 ,使得所有詞頻相似的詞隱藏層啟用的內容基本一致,出現頻率越高的詞語,他們啟用的隱藏層數目越少,這樣有效的降低了計算的複雜度。

這個三層神經網路本身是 對語言模型進行建模 ,但也同時 獲得一種單詞在向量空間上的表示,而這個副作用才是Word2vec的真正目標

word2vec對之前的模型做了改進,

  • 首先,對於從輸入層到隱藏層的對映,沒有采取神經網路的線性變換加啟用函式的方法,而是採用簡單的對所有輸入詞向量求和並取平均的方法。比如輸入的是三個4維詞向量:(1,2,3,4),(9,6,11,8),(5,10,7,12),那麼我們word2vec對映後的詞向量就是(5,6,7,8)。由於這裡是從多個詞向量變成了一個詞向量。
  • 第二個改進就是從隱藏層到輸出的softmax層這裡的計算量個改進。為了避免要計算所有詞的softmax概率,word2vec取樣了霍夫曼樹來代替從隱藏層到輸出softmax層的對映。

1.3.2 Hierarchical Softmax基本思路

Word2vec計算可以用 層次Softmax演算法 ,這種演算法結合了Huffman編碼,其實藉助了分類問題中,使用一連串二分類近似多分類的思想。例如我們是把所有的詞都作為輸出,那麼“桔子”、“汽車”都是混在一起。給定w_t的上下文,先讓模型判斷w_t是不是名詞,再判斷是不是食物名,再判斷是不是水果,再判斷是不是“桔子”。

取一個適當大小的視窗當做語境,輸入層讀入視窗內的詞,將它們的向量(K維,初始隨機)加和在一起,形成隱藏層K個節點。輸出層是一個巨大的二叉樹,葉節點代表語料裡所有的詞(語料含有V個獨立的詞,則二叉樹有|V|個葉節點)。而這整顆二叉樹構建的演算法就是Huffman樹。

這樣,語料庫中的某個詞w_t 都對應著二叉樹的某個葉子節點,這樣每個詞 w 都可以從樹的根結點root沿著唯一一條路徑被訪問到,其路徑也就形成了其全域性唯一的二進位制編碼code,如"010011"。

不妨記左子樹為1,右子樹為0。接下來,隱層的每一個節點都會跟二叉樹的內節點有連邊,於是對於二叉樹的每一個內節點都會有K條連邊,每條邊上也會有權值。假設 n(w, j)為這條路徑上的第 j 個結點,且 L(w)為這條路徑的長度, j 從 1 開始編碼,即 n(w, 1)=root,n(w, L(w)) = w。對於第 j 個結點,層次 Softmax 定義的Label 為 1 - code[j]。

在訓練階段,當給定上下文,要預測後面的詞w_t的時候,我們就從二叉樹的根節點開始遍歷,這裡的目標就是預測這個詞的二進位制編號的每一位。即對於給定的上下文,我們的目標是使得預測詞的二進位制編碼概率最大。形象地說,對於 "010011",我們希望在根節點,詞向量和與根節點相連經過logistic計算得到bit=1的概率儘量接近0,在第二層,希望其bit=1的概率儘量接近1,這麼一直下去,我們把一路上計算得到的概率相乘,即得到目標詞w_t在當前網路下的概率P(w_t),那麼對於當前這個sample的殘差就是1-P(w_t),於是就可以使用梯度下降法訓練這個網路得到所有的引數值了。顯而易見,按照目標詞的二進位制編碼計算到最後的概率值就是歸一化的。

在訓練過程中,模型會賦予這些抽象的中間結點一個合適的向量,這個向量代表了它對應的所有子結點。因為真正的單詞公用了這些抽象結點的向量,所以Hierarchical Softmax方法和原始問題並不是等價的,但是這種近似並不會顯著帶來效能上的損失同時又使得模型的求解規模顯著上升。

1.3.3 Hierarchical Softmax 數學推導

傳統的Softmax可以看成是一個線性表,平均查詢時間O(n)。HS方法將Softmax做成一顆平衡的滿二叉樹,維護詞頻後,變成Huffman樹。

由於我們把之前所有都要計算的從輸出softmax層的概率計算變成了一顆二叉霍夫曼樹,那麼我們的softmax概率計算只需要沿著樹形結構進行就可以了。我們可以沿著霍夫曼樹從根節點一直走到我們的葉子節點的詞w2

和之前的神經網路語言模型相比,我們的霍夫曼樹的所有內部節點就類似之前神經網路隱藏層的神經元,其中,根節點的詞向量對應我們的投影后的詞向量,而所有葉子節點就類似於之前神經網路softmax輸出層的神經元,葉子節點的個數就是詞彙表的大小。在霍夫曼樹中,隱藏層到輸出層的softmax對映不是一下子完成的,而是沿著霍夫曼樹一步步完成的,因此這種softmax取名為"Hierarchical Softmax"。

如何“沿著霍夫曼樹一步步完成”呢?在word2vec中,我們採用了二元邏輯迴歸的方法,即規定沿著左子樹走,那麼就是負類(霍夫曼樹編碼1),沿著右子樹走,那麼就是正類(霍夫曼樹編碼0)。判別正類和負類的方法是使用sigmoid函式即:

\[P(+) = \sigma(x_w^T\theta) = \frac{1}{1+e^{-x_w^T\theta}} \]

其中xw是當前內部節點的詞向量,而θ則是我們需要從訓練樣本求出的邏輯迴歸的模型引數

使用霍夫曼樹有什麼好處呢?

  • 首先,由於是二叉樹,之前計算量為V,現在變成了log2V。
  • 第二,由於使用霍夫曼樹是高頻的詞靠近樹根,這樣高頻詞需要更少的時間會被找到,這符合我們的貪心優化思想。

容易理解,被劃分為左子樹而成為負類的概率為P(−)=1−P(+)。在某一個內部節點,要判斷是沿左子樹還是右子樹走的標準就是看P(−),P(+)誰的概率值大。而控制P(−),P(+)誰的概率值大的因素一個是當前節點的詞向量,另一個是當前節點的模型引數θ

對於上圖中的w2,如果它是一個訓練樣本的輸出,那麼我們期望對於裡面的隱藏節點n(w2,1)P(−)概率大,n(w2,2)P(−)概率大,n(w2,3)P(+)概率大。

回到基於Hierarchical Softmax的word2vec本身,我們的目標就是找到合適的所有節點的詞向量和所有內部節點θ, 使訓練樣本達到最大似然。

定義 w 經過的霍夫曼樹某一個節點j的邏輯迴歸概率為:

\[P(d_j^w|x_w, \theta_{j-1}^w)= \begin{cases} \sigma(x_w^T\theta_{j-1}^w)& {d_j^w=0}\\ 1-\sigma(x_w^T\theta_{j-1}^w) & {d_j^w = 1} \end{cases} \]

那麼對於某一個目標輸出詞w,其最大似然為:

\[\prod_{j=2}^{l_w}P(d_j^w|x_w, \theta_{j-1}^w) = \prod_{j=2}^{l_w} [\sigma(x_w^T\theta_{j-1}^w)] ^{1-d_j^w}[1-\sigma(x_w^T\theta_{j-1}^w)]^{d_j^w} \]

在word2vec中,由於使用的是隨機梯度上升法,所以並沒有把所有樣本的似然乘起來得到真正的訓練集最大似然,僅僅每次只用一個樣本更新梯度,這樣做的目的是減少梯度計算量。

可以求出的梯度表示式如下:

\[\frac{\partial L}{\partial x_w} = \sum\limits_{j=2}^{l_w}(1-d_j^w-\sigma(x_w^T\theta_{j-1}^w))\theta_{j-1}^w \]

有了梯度表示式,我們就可以用梯度上升法進行迭代來一步步的求解我們需要的所有的θwj−1和xw。

注意!word2vec要訓練兩組引數:一個是網路隱藏層的引數,一個是輸入單詞的引數(1 * dim)

在skip gram和CBOW中,中心詞詞向量在迭代過程中是不會更新的,只更新視窗詞向量,這個中心詞對應的詞向量需要下一次在作為非中心詞的時候才能進行迭代更新。

0x02 帶著問題閱讀

Alink的實現核心是以 https://github.com/tmikolov/word2vec 為基礎進行修改,實際上如果不是對C語言非常牴觸,建議先閱讀這個程式碼。因為Alink的並行處理程式碼真的挺難理解,尤其是資料預處理部分。

以問題為導向:

  • 哪些模組用到了Alink的分散式處理能力?
  • Alink實現了Word2vec的哪個模型?是CBOW模型還是skip-gram模型?
  • Alink用到了哪個優化方法?是Hierarchical Softmax?還是Negative Sampling?
  • 是否在本演算法內去除停詞?所謂停用詞,就是出現頻率太高的詞,如逗號,句號等等,以至於沒有區分度。
  • 是否使用了自適應學習率?

0x03 示例程式碼

我們把Alink的測試程式碼修改下。需要說明的是Word2vec也吃記憶體,所以我的機器上需要配置VM啟動引數:-Xms256m -Xmx640m -XX:PermSize=128m -XX:MaxPermSize=512m

public class Word2VecTest {
    public static void main(String[] args) throws Exception {
        TableSchema schema = new TableSchema(
                new String[] {"docid", "content"},
                new TypeInformation <?>[] {Types.LONG(), Types.STRING()}
        );
        List <Row> rows = new ArrayList <>();
        rows.add(Row.of(0L, "老王 是 我們 團隊 裡 最胖 的"));
        rows.add(Row.of(1L, "老黃 是 第二 胖 的"));
        rows.add(Row.of(2L, "胖"));
        rows.add(Row.of(3L, "胖 胖 胖"));

        MemSourceBatchOp source = new MemSourceBatchOp(rows, schema);

        Word2Vec word2Vec = new Word2Vec()
                .setSelectedCol("content")
                .setOutputCol("output")
                .setMinCount(1);

        List<Row> result = word2Vec.fit(source).transform(source).collect();
        System.out.println(result);
    }
}

程式輸出是

[0,老王 是 我們 團隊 裡 最胖 的,0.8556591824716802 0.4185472857807756 0.5917632873908979 0.445803358747732 0.5351499521578621 0.6559828965377957 0.5965739474021792 0.473846881662404 0.516117276817363 0.3434555277582306 0.38403383919352685 ..., 
 
1,老黃 是 第二 胖 的,0.9227240557894372 0.5697617202790405 0.42338677208067105 0.5483285740408497 0.5950012315151869 0.4155926470754411 0.6283449603326386 0.47098108241809644 0.2874100346124693 0.41205111525453264 0.59972461077888 ..., 
 
3,胖 胖 胖,0.9220798404216994 0.8056990255747927 0.166767439210223 0.1651382099869762 0.7498624766177563 0.12363837145024788 0.16301554444226507 0.5992360550912706 0.6408649011941911 0.5504539398019214 0.4935531765920934 0.13805809361251292 0.2869384374291237 0.47796081976004645 0.6305720374272978 0.1745491550099714 ...]

0x04 整體邏輯

4.1 Word2vec大概流程

  1. 分詞 / 詞幹提取和詞形還原。 中文和英文的nlp各有各的難點,中文的難點在於需要進行分詞,將一個個句子分解成一個單詞陣列。而英文雖然不需要分詞,但是要處理各種各樣的時態,所以要進行詞幹提取和詞形還原。
  2. 構造詞典,統計詞頻。這一步需要遍歷一遍所有文字,找出所有出現過的詞,並統計各詞的出現頻率。
  3. 構造樹形結構。依照出現概率構造Huffman樹。如果是完全二叉樹,則簡單很多。需要注意的是,所有分類都應該處於葉節點。
  4. 生成節點所在的二進位制碼。這個二進位制碼反映了節點在樹中的位置,就像門牌號一樣,能按照編碼從根節點一步步找到對應的葉節點。
  5. 初始化各非葉節點的中間向量和葉節點中的詞向量。樹中的各個節點,都儲存著一個長為m的向量,但葉節點和非葉結點中的向量的含義不同。葉節點中儲存的是各詞的詞向量,是作為神經網路的輸入的。而非葉結點中儲存的是中間向量,對應於神經網路中隱含層的引數,與輸入一起決定分類結果。
  6. 訓練中間向量和詞向量。對於CBOW模型,首先將某詞A附近的n-1個詞的詞向量相加作為系統的輸入,並且按照詞A在步驟4中生成的二進位制碼,一步步的進行分類並按照分類結果訓練中間向量和詞向量。舉個栗子,對於某節點,我們已經知道其二進位制碼是100。那麼在第一個中間節點應該將對應的輸入分類到右邊。如果分類到左邊,則表明分類錯誤,需要對向量進行修正。第二個,第三個節點也是這樣,以此類推,直到達到葉節點。因此對於單個單詞來說,最多隻會改動其路徑上的節點的中間向量,而不會改動其他節點。

4.2 訓練程式碼

Word2VecTrainBatchOp 類是訓練的程式碼實現,其linkFrom函式體現了程式的總體邏輯,其省略版程式碼如下,具體後期我們會一一詳述。

  public Word2VecTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> in = checkAndGetFirst(inputs);
    final int vectorSize = getVectorSize();
    
    // 計算單詞出現次數
    DataSet <Row> wordCnt = WordCountUtil
      .splitDocAndCount(in, getSelectedCol(), getWordDelimiter())
      .filter("cnt >= " + String.valueOf(getMinCount()))
      .getDataSet();

    // 根據詞頻對單詞進行排序
    DataSet <Row> sorted = sortedIndexVocab(wordCnt);
    // 計算排序之後單詞數目
    DataSet <Long> vocSize = DataSetUtils
      .countElementsPerPartition(sorted)
      .sum(1)
      .map(new MapFunction <Tuple2 <Integer, Long>, Long>() {
        @Override
        public Long map(Tuple2 <Integer, Long> value) throws Exception {
          return value.f1;
        }
      });
    // 建立字典和二叉樹
    DataSet <Tuple3 <Integer, String, Word>> vocab = sorted
      .reduceGroup(new CreateVocab())
      .withBroadcastSet(vocSize, "vocSize")
      .rebalance();
    // 再次分割單詞
    DataSet <String[]> split = in
      .select("`" + getSelectedCol() + "`")
      .getDataSet()
      .flatMap(new WordCountUtil.WordSpliter(getWordDelimiter()))
      .rebalance();
    // 生成訓練資料
    DataSet <int[]> trainData = encodeContent(split, vocab)
      .rebalance();

    final long seed = System.currentTimeMillis();
    // 獲取精簡詞典
    DataSet <Tuple2 <Integer, Word>> vocabWithoutWordStr = vocab
      .map(new UseVocabWithoutWordString());
    
    // 初始化模型
    DataSet <Tuple2 <Integer, double[]>> initialModel = vocabWithoutWordStr
      .mapPartition(new initialModel(seed, vectorSize))
      .rebalance();
    // 計算迭代次數
    DataSet <Integer> syncNum = DataSetUtils
      .countElementsPerPartition(trainData)
      .sum(1)
      .map(new MapFunction <Tuple2 <Integer, Long>, Integer>() {
        @Override
        public Integer map(Tuple2 <Integer, Long> value) throws Exception {
          return Math.max((int) (value.f1 / 100000L), 5);
        }
      });
    
    // 迭代訓練
    DataSet <Row> model = new IterativeComQueue()
      .initWithPartitionedData("trainData", trainData)
      .initWithBroadcastData("vocSize", vocSize)
      .initWithBroadcastData("initialModel", initialModel)
      .initWithBroadcastData("vocabWithoutWordStr", vocabWithoutWordStr)
      .initWithBroadcastData("syncNum", syncNum)
      .add(new InitialVocabAndBuffer(getParams()))
      .add(new UpdateModel(getParams()))
      .add(new AllReduce("input"))
      .add(new AllReduce("output"))
      .add(new AvgInputOutput())
      .setCompareCriterionOfNode0(new Criterion(getParams()))
      .closeWith(new SerializeModel(getParams()))
      .exec();
    
    // 輸出模型
    model = model
      .map(new MapFunction <Row, Tuple2 <Integer, DenseVector>>() {
        @Override
        public Tuple2 <Integer, DenseVector> map(Row value) throws Exception {
          return Tuple2.of((Integer) value.getField(0), (DenseVector) value.getField(1));
        }
      })
      .join(vocab)
      .where(0)
      .equalTo(0)
      .with(new JoinFunction <Tuple2 <Integer, DenseVector>, Tuple3 <Integer, String, Word>, Row>() {
        @Override
        public Row join(Tuple2 <Integer, DenseVector> first, Tuple3 <Integer, String, Word> second)
          throws Exception {
          return Row.of(second.f1, first.f1);
        }
      })
      .mapPartition(new MapPartitionFunction <Row, Row>() {
        @Override
        public void mapPartition(Iterable <Row> values, Collector <Row> out) throws Exception {
          Word2VecModelDataConverter model = new Word2VecModelDataConverter();

          model.modelRows = StreamSupport
            .stream(values.spliterator(), false)
            .collect(Collectors.toList());

          model.save(model, out);
        }
      });

    setOutput(model, new Word2VecModelDataConverter().getModelSchema());

    return this;
  }

0x05 處理輸入

此部分是最複雜的,也是和 C 程式碼 差異最大的地方。因為Alink需要考慮處理大規模輸入資料,所以進行了分散式處理,而一旦分散式處理,就會各種細節糾纏在一起。

5.1 計算單詞出現次數

這部分程式碼如下,具體又分為兩個部分。

DataSet <Row> wordCnt = WordCountUtil
      .splitDocAndCount(in, getSelectedCol(), getWordDelimiter())
      .filter("cnt >= " + String.valueOf(getMinCount()))
      .getDataSet();

5.1.1 分割單詞&計數

此處邏輯相對清晰,就是 分割單詞 splitDoc, 然後計數 count。

public static BatchOperator<?> splitDocAndCount(BatchOperator<?> input, String docColName, String wordDelimiter) {
  return count(splitDoc(input, docColName, wordDelimiter), WORD_COL_NAME, COUNT_COL_NAME);
}
5.1.1.1 分割單詞

分割單詞使用 DocWordSplitCount 這個UDTF。

public static BatchOperator splitDoc(BatchOperator<?> input, String docColName, String wordDelimiter) {
  return input.udtf(
    docColName,
    new String[] {WORD_COL_NAME, COUNT_COL_NAME},
    new DocWordSplitCount(wordDelimiter),
    new String[] {}
  );
}

DocWordSplitCount的功能就是分割單詞,計數。

public class DocWordSplitCount extends TableFunction <Row> {

  private String delimiter;

  public DocWordSplitCount(String delimiter) {
    this.delimiter = delimiter;
  }

  public void eval(String content) {
    String[] words = content.split(this.delimiter); // 分割單詞
    HashMap <String, Long> map = new HashMap <>(0);

    for (String word : words) {
      if (word.length() > 0) {
        map.merge(word, 1L, Long::sum); // 計數
      }
    }

    for (Map.Entry <String, Long> entry : map.entrySet()) {
      collect(Row.of(entry.getKey(), entry.getValue())); // 傳送二元組<單詞,個數>
    }
  }
}

// runtime時候,變數如下:
content = "老王 是 我們 團隊 裡 最胖 的"
words = {String[7]@10021} 
 0 = "老王"
 1 = "是"
 2 = "我們"
 3 = "團隊"
 4 = "裡"
 5 = "最胖"
 6 = "的"
map = {HashMap@10024}  size = 7
 "最胖" -> {Long@10043} 1
 "的" -> {Long@10043} 1
 "裡" -> {Long@10043} 1
 "老王" -> {Long@10043} 1
 "團隊" -> {Long@10043} 1
 "我們" -> {Long@10043} 1
 "是" -> {Long@10043} 1
5.1.1.2 計數

此處會把分散式計算出來的 二元組<單詞,個數> 做 groupBy,這樣就得到了最終的 單詞出現次數。其中 Flink 的groupBy起到了關鍵作用,大家有興趣可以閱讀 [ 原始碼解析] Flink的groupBy和reduce究竟做了什麼

public static BatchOperator count(BatchOperator input, String wordColName) {
    return count(input, wordColName, null);
}

public static BatchOperator count(BatchOperator input, String wordColName, String wordValueColName) {
    if (null == wordValueColName) {
      return input.groupBy(wordColName,
        wordColName + " AS " + WORD_COL_NAME + ", COUNT(" + wordColName + ") AS " + COUNT_COL_NAME);
    } else {
      return input.groupBy(wordColName,
        wordColName + " AS " + WORD_COL_NAME + ", SUM(" + wordValueColName + ") AS " + COUNT_COL_NAME);
    }
}

5.1.2 過濾低頻詞

如果單詞出現次數太少,就沒有加入字典的必要,所以需要過濾。

5.1.2.1 配置

Word2VecTrainBatchOp 需要實現配置引數 Word2VecTrainParams,具體如下:

public interface Word2VecTrainParams<T> extends
    HasNumIterDefaultAs1<T>,
  HasSelectedCol <T>,
  HasVectorSizeDv100 <T>,
  HasAlpha <T>,
  HasWordDelimiter <T>,
  HasMinCount <T>,
  HasRandomWindow <T>,
  HasWindow <T> {
}

其中 HasMinCount 就是用來配置低頻單詞的閾值。

public interface HasMinCount<T> extends WithParams<T> {
  ParamInfo <Integer> MIN_COUNT = ParamInfoFactory
    .createParamInfo("minCount", Integer.class)
    .setDescription("minimum count of word")
    .setHasDefaultValue(5)
    .build();

  default Integer getMinCount() {
    return get(MIN_COUNT);
  }

  default T setMinCount(Integer value) {
    return set(MIN_COUNT, value);
  }
}

在例項程式碼中有如下,就是設定最低閾值是 1,這是因為我們的輸入很少,不會過濾低頻詞。如果詞彙量多,可以設定為 5。

.setMinCount(1);
5.1.2.2 過濾

我們再取出使用程式碼.

DataSet <Row> wordCnt = WordCountUtil
      .splitDocAndCount(in, getSelectedCol(), getWordDelimiter())
      .filter("cnt >= " + String.valueOf(getMinCount()))
      .getDataSet();

可以看到,.filter("cnt >= " + String.valueOf(getMinCount())) 這部分是過濾。這是簡單的SQL用法。

然後會返回 DataSet wordCnt。

5.2 依據詞頻對單詞排序

過濾低頻單詞之後,會對得到的單詞進行排序。

DataSet <Row> sorted = sortedIndexVocab(wordCnt);

此處比較艱深晦澀,需要仔細梳理,大致邏輯是:

  • 1)使用 SortUtils.pSort 對<單詞,頻次> 進行大規模並行排序;
  • 2)對 上一步的返回值 f0 進行分割槽 sorted.f0.partitionCustom , 因為上一步返回值的 f0 是 <partition id, Row> ,得倒資料集 partitioned。
  • 3)計算每個分割槽的單詞數目 countElementsPerPartition(partitioned) ; 得倒 Tuple2 ; 得倒的結果資料集 cnt 會廣播出來,下一步計算時候會用到;
  • 4)在各個分割槽內(就是第二步得倒的資料集 partitioned)利用 mapPartition 對單詞進行排序,利用到了上步的 cnt ;
    • open 函式中,會計算 本分割槽內 所有單詞的總數total、本區單詞數目curLen,本區單詞起始位置 start
    • mapPartition 函式中,會排序,歸併,最後發出資料集 DataSet

注1,pSort 可以參見 Alink漫談(六) : TF-IDF演算法的實現。SortUtils.pSort是大規模並行排序。pSort返回值是: @return f0: dataset which is indexed by partition id, f1: dataset which has partition id and count.

具體實現如下:

private static DataSet <Row> sortedIndexVocab(DataSet <Row> vocab) {
    final int sortIdx = 1;
    Tuple2 <DataSet <Tuple2 <Integer, Row>>, DataSet <Tuple2 <Integer, Long>>> sorted
      = SortUtils.pSort(vocab, sortIdx); // 進行大規模並行排序

    DataSet <Tuple2 <Integer, Row>> partitioned = sorted.f0.partitionCustom(new Partitioner <Integer>() {
      @Override
      public int partition(Integer key, int numPartitions) {
        return key; // 利用分割槽 idx 進行分割槽
      }
    }, 0);

    DataSet <Tuple2 <Integer, Long>> cnt = DataSetUtils.countElementsPerPartition(partitioned);

    return partitioned.mapPartition(new RichMapPartitionFunction <Tuple2 <Integer, Row>, Row>() {
      int start;
      int curLen;
      int total;

      @Override
      public void open(Configuration parameters) throws Exception {
        List <Tuple2 <Integer, Long>> cnts = getRuntimeContext().getBroadcastVariable("cnt");
        int taskId = getRuntimeContext().getIndexOfThisSubtask();
        start = 0;
        curLen = 0;
        total = 0;

        for (Tuple2 <Integer, Long> val : cnts) {
          if (val.f0 < taskId) {
            start += val.f1; // 本區單詞起始位置 
          }

          if (val.f0 == taskId) {  // 只計算本分割槽對應的記錄,因為 f0 是分割槽idx
            curLen = val.f1.intValue(); // 本區單詞數目curLen
          }

          total += val.f1.intValue(); // 得倒 本分割槽內 所有單詞的總數total
        }
                
// runtime 列印如下                
val = {Tuple2@10585} "(7,0)"
 f0 = {Integer@10586} 7
 f1 = {Long@10587} 0                
                
      }

      @Override
      public void mapPartition(Iterable <Tuple2 <Integer, Row>> values, Collector <Row> out) throws Exception {

        Row[] all = new Row[curLen];

        int i = 0;
        for (Tuple2 <Integer, Row> val : values) {
          all[i++] = val.f1; // 得倒所有的單詞
        }

        Arrays.sort(all, (o1, o2) -> (int) ((Long) o1.getField(sortIdx) - (Long) o2.getField(sortIdx))); // 排序

        i = start;
        for (Row row : all) {
          // 歸併 & 傳送
          out.collect(RowUtil.merge(row, -(i - total + 1)));
          ++i;
        }
                
// runtime時的變數如下:                
all = {Row[2]@10655} 
 0 = {Row@13346} "我們,1"
 1 = {Row@13347} "裡,1"
i = 0
total = 10
start = 0
      }
    }).withBroadcastSet(cnt, "cnt"); // 廣播進來的變數
}

5.2.1 排序後單詞數目

此處是計算排序後每個分割槽的單詞數目,相對邏輯簡單,其結果資料集 會廣播出來給下一步使用。

DataSet <Long> vocSize = DataSetUtils // vocSize是詞彙的個數
      .countElementsPerPartition(sorted)
      .sum(1) // 累計第一個key
      .map(new MapFunction <Tuple2 <Integer, Long>, Long>() {
        @Override
        public Long map(Tuple2 <Integer, Long> value) throws Exception {
          return value.f1;
        }
      });

5.3 建立詞典&二叉樹

本部分會利用上兩步得倒的結果:"排序好的單詞"&"每個分割槽的單詞數目" 來建立 詞典 和 二叉樹。

DataSet <Tuple3 <Integer, String, Word>> vocab = sorted // 排序後的單詞資料集
      .reduceGroup(new CreateVocab())
      .withBroadcastSet(vocSize, "vocSize") // 廣播上一步產生的結果集
      .rebalance();

CreateVocab 完成了具體工作,結果集是:Tuple3<單詞在詞典的idx,單詞,單詞在詞典中對應的元素>。

private static class CreateVocab extends RichGroupReduceFunction <Row, Tuple3 <Integer, String, Word>> {
    int vocSize;

    @Override
    public void open(Configuration parameters) throws Exception {
      vocSize = getRuntimeContext().getBroadcastVariableWithInitializer("vocSize",
        new BroadcastVariableInitializer <Long, Integer>() {
          @Override
          public Integer initializeBroadcastVariable(Iterable <Long> data) {
            return data.iterator().next().intValue();
          }
        });
    }

    @Override
    public void reduce(Iterable <Row> values, Collector <Tuple3 <Integer, String, Word>> out) throws Exception {
      String[] words = new String[vocSize];
      Word[] vocab = new Word[vocSize];

            // 建立詞典
      for (Row row : values) {
        Word word = new Word();
        word.cnt = (long) row.getField(1);
        vocab[(int) row.getField(2)] = word;
        words[(int) row.getField(2)] = (String) row.getField(0);
      }

// runtime變數如下
words = {String[10]@10606} 
 0 = "胖"
 1 = "的"
 2 = "是"
 3 = "團隊"
 4 = "老王"
 5 = "第二"
 6 = "最胖"
 7 = "老黃"
 8 = "裡"
 9 = "我們"            
            
      // 建立二叉樹,建立過程中會更新詞典內容
      createBinaryTree(vocab);

// runtime變數如下            
vocab = {Word2VecTrainBatchOp$Word[10]@10669} 
 0 = {Word2VecTrainBatchOp$Word@13372} 
  cnt = 5
  point = {int[2]@13382} 
   0 = 8
   1 = 7
  code = {int[2]@13383} 
   0 = 1
   1 = 1
 1 = {Word2VecTrainBatchOp$Word@13373} 
  cnt = 2
  point = {int[3]@13384} 
   0 = 8
   1 = 7
   2 = 5
  code = {int[3]@13385} 
   0 = 1
   1 = 0
   2 = 1            
            
      for (int i = 0; i < vocab.length; ++i) {
        // 結果集是:Tuple3<單詞在詞典的idx,單詞,單詞對應的詞典元素>
        out.collect(Tuple3.of(i, words[i], vocab[i]));
      }        
    }
}

5.3.1 資料結構

詞典的資料結構如下:

private static class Word implements Serializable {
  public long cnt; // 詞頻,左右兩個輸入節點的詞頻之和
  public int[] point; //在樹中的節點序列, 即從根結點到葉子節點的路徑
  public int[] code; //霍夫曼碼, HuffmanCode
}

一個容易混淆的地方:

  • vocab[word].code[d] 指的是,當前單詞word的,第d個編碼,編碼不含Root結點
  • vocab[word].point[d] 指的是,當前單詞word,第d個編碼下,前置結點。

比如vocab[word].point[0]肯定是Root結點,而 vocab[word].code[0]肯定是Root結點走到下一個點的編碼。

5.3.2 建立二叉樹

這裡基於語料訓練樣本建立霍夫曼樹(基於詞頻)。

Alink這裡基本就是c語言的java實現。可能很多兄弟還不熟悉,所以需要講解下。

Word2vec 利用陣列下標的移動就完成了構建、編碼。它最重要的是隻用了parent這個陣列來標記生成的Parent結點( 範圍 VocabSize,VocabSize∗2−2 )。最後對Parent結點減去VocabSize,得到從0開始的Point路徑陣列。

基本套路是:

  • 首先,設定兩個指標pos1和pos2,分別指向最後一個詞和最後一個詞的後一位;
  • 然後,從兩個指標所指的數中選擇出最小的值。記為min1i。如pos1所指的值最小,此時,將pos1左移,再比較 pos1和pos2所指的數。選擇出最小的值,記為min2i,將他們的和儲存到pos2所指的位置。
  • 並將此時pos2所指的位置設定為min1i和min2i的父節點,同一時候,記min2i所指的位置的編碼為1。
private static void createBinaryTree(Word[] vocab) {
    int vocabSize = vocab.length;

    int[] point = new int[MAX_CODE_LENGTH];
    int[] code = new int[MAX_CODE_LENGTH];
        // 首先定義了3個長度為vocab_size*2+1的陣列
        // count陣列中前vocab_size儲存的是每個詞的相應的詞頻。後面初始化的是非常大的數,已知詞庫中的詞是依照降序排列的。
    long[] count = new long[vocabSize * 2 - 1];
    int[] binary = new int[vocabSize * 2 - 1];
    int[] parent = new int[vocabSize * 2 - 1];

      // 前半部分初始化為每個詞出現的次數
    for (int i = 0; i < vocabSize; ++i) {
      count[i] = vocab[i].cnt;
    }
    // 後半部分初始化為一個固定的常數
    Arrays.fill(count, vocabSize, vocabSize * 2 - 1, Integer.MAX_VALUE);

    // pos1, pos2 可以理解為 下一步 將要構建的左右兩個節點
    // min1i, min2i 是當前正在構建的左右兩個節點
    int min1i, min2i, pos1, pos2;

    pos1 = vocabSize - 1; // pos1指向前半截的尾部
    pos2 = vocabSize; // pos2指向後半截的開始

    // 每次增加一個節點,構建Huffman樹
    for (int a = 0; a < vocabSize - 1; ++a) {
      // First, find two smallest nodes 'min1, min2'
      // 選擇最小的節點min1
      // 根據pos1, pos2找到目前的 左 min1i 的位置,並且調整下一次的pos1, pos2
      if (pos1 >= 0) {
        if (count[pos1] < count[pos2]) {
          min1i = pos1;
          pos1--;
        } else {
          min1i = pos2;
          pos2++;
        }
      } else {
        min1i = pos2;
        pos2++;
      }
            
      // 選擇最小的節點min2
      // 根據上一步調整的pos1, pos2找到目前的 右 min2i 的位置,並且調整下一次的pos1, pos2
      if (pos1 >= 0) {
        if (count[pos1] < count[pos2]) {
          min2i = pos1;
          pos1--;
        } else {
          min2i = pos2;
          pos2++;
        }
      } else {
        min2i = pos2;
        pos2++;
      }

      // 新生成的節點的概率是兩個輸入節點的概率之和,其左右子節點即為輸入的兩個節點。值得注意的是,新生成的節點肯定不是葉節點,而非葉結點的value值是中間向量,初始化為零向量。
      count[vocabSize + a] = count[min1i] + count[min2i];
      parent[min1i] = vocabSize + a; // 設定父節點
      parent[min2i] = vocabSize + a;
      binary[min2i] = 1;  // 設定一個子樹的編碼為1
    }
    
// runtime變數如下:
binary = {int[19]@13405}  0 = 1 1 = 1 2 = 0 3 = 0 4 = 1 5 = 0 6 = 1 7 = 0 8 = 1 9 = 0 10 = 1 11 = 0 12 = 1 13 = 0 14 = 1 15 = 0 16 = 0 17 = 1 18 = 0
    
parent = {int[19]@13406}  0 = 17 1 = 15 2 = 15 3 = 13 4 = 12 5 = 12 6 = 11 7 = 11 8 = 10 9 = 10 10 = 13 11 = 14 12 = 14 13 = 16 14 = 16 15 = 17 16 = 18 17 = 18 18 = 0    
    
count = {long[19]@13374}  0 = 5 1 = 2 2 = 2 3 = 1 4 = 1 5 = 1 6 = 1 7 = 1 8 = 1 9 = 1 10 = 2 11 = 2 12 = 2 13 = 3 14 = 4 15 = 4 16 = 7 17 = 9 18 = 16    
    
      // Now assign binary code to each vocabulary word
      // 生成Huffman碼,即找到每一個字的code,和對應的在樹中的節點序列,在生成Huffman編碼的過程中。針對每個詞(詞都在葉子節點上),從葉子節點開始。將編碼存入到code陣列中,如對於上圖中的“R”節點來說。其code陣列為{1,0}。再對其反轉便是Huffman編碼:
    for (int a = 0; a < vocabSize; ++a) { // 為每一個詞分配二進位制編碼,即Huffman編碼
      int b = a;
      int i = 0;

      do {
        code[i] = binary[b]; // 找到當前的節點的編碼
        point[i] = b; // 記錄從葉子節點到根結點的序列
        i++;
        b = parent[b]; // 找到當前節點的父節點
      } while (b != vocabSize * 2 - 2); // 已經找到了根結點,根節點是沒有編碼的

      vocab[a].code = new int[i];

      for (b = 0; b < i; ++b) {
        vocab[a].code[i - b - 1] = code[b]; // 編碼的反轉
      }

      vocab[a].point = new int[i];
      vocab[a].point[0] = vocabSize - 2;
      for (b = 1; b < i; ++b) {
        vocab[a].point[i - b] = point[b] - vocabSize; // 記錄的是從根結點到葉子節點的路徑
      }
    }
}

最終二叉樹結果如下:

vocab = {Word2VecTrainBatchOp$Word[10]@10608} 
 0 = {Word2VecTrainBatchOp$Word@13314} 
  cnt = 5
  point = {int[2]@13329} 
   0 = 8
   1 = 7
  code = {int[2]@13330} 
   0 = 1
   1 = 1
 1 = {Word2VecTrainBatchOp$Word@13320} 
  cnt = 2
  point = {int[3]@13331} 
   0 = 8
   1 = 7
   2 = 5
  code = {int[3]@13332} 
   0 = 1
   1 = 0
   2 = 1
 2 = {Word2VecTrainBatchOp$Word@13321} 
 3 = {Word2VecTrainBatchOp$Word@13322} 
 ......
 9 = {Word2VecTrainBatchOp$Word@13328} 

5.4 分割單詞

此處會再次對原始輸入做單詞分割,這裡總感覺是可以把此步驟和前面步驟放在一起做優化。

DataSet <String[]> split = in
      .select("`" + getSelectedCol() + "`")
      .getDataSet()
      .flatMap(new WordCountUtil.WordSpliter(getWordDelimiter()))
      .rebalance();

5.5 生成訓練資料

生成訓練資料程式碼如下,此處也比較晦澀。

DataSet <int[]> trainData = encodeContent(split, vocab).rebalance();

最終目的是,把每個句子都翻譯成了一個詞典idx的序列,比如:

原始輸入 : "老王 是 我們 團隊 裡 最胖 的"

編碼之後 : “4,1,9,3,8,6,2” , 這裡每個數字是 本句子中每個單詞在詞典中的序列號。

encodeContent 的輸入是:

  • 已經分割好的原始輸入(其實本文示例中的原始輸入就是用空格分隔的),對於encodeContent 來說就是一個一個句子;
  • 詞典資料集 Tuple3<單詞在詞典的idx,單詞,單詞在詞典中對應的元素>;

流程邏輯如下:

  • 對輸入的句子分割槽處理 content.mapPartition,得到資料集 Tuple4 <>(taskId, localCnt, i, val[i]),分別是 Tuple4 <>(taskId, 本分割槽句子數目, 本單詞在本句子中的idx, 本單詞),所以此處傳送的核心是單詞。
  • 使用了 Flink coGroup 功能完成了雙流匹配合並功能,將單詞流和詞典篩選合併(where(3).equalTo(1)),其中上步處理中,f3是word,vocab.f1 是word,所以就是在兩個流中找到相同的單詞然後做操作。得倒 Tuple4.of(tuple.f0, tuple.f1, tuple.f2, row.getField(0))),即 結果集是 Tuple4 <taskId, 本分割槽句子數目, 本單詞在本句子中的idx,單詞在詞典的idx>
  • 分組排序,歸併 groupBy(0, 1).reduceGroup,然後排序(根據本單詞在本句子中的idx來排序);結果集是 DataSet <int[]>,即返回 “本單詞在詞典的idx”,比如 [4,1,9,3,8,6,2] 。就是本句子中每個單詞在詞典中的序列號。

具體程式碼如下:

private static DataSet <int[]> encodeContent(
    DataSet <String[]> content,
    DataSet <Tuple3 <Integer, String, Word>> vocab) {
    return content
      .mapPartition(new RichMapPartitionFunction <String[], Tuple4 <Integer, Long, Integer, String>>() {
        @Override
        public void mapPartition(Iterable <String[]> values,
                     Collector <Tuple4 <Integer, Long, Integer, String>> out)
          throws Exception {
          int taskId = getRuntimeContext().getIndexOfThisSubtask();
          long localCnt = 0L;
          for (String[] val : values) {
            if (val == null || val.length == 0) {
              continue;
            }

            for (int i = 0; i < val.length; ++i) {
              // 核心是傳送單詞
              out.collect(new Tuple4 <>(taskId, localCnt, i, val[i]));
            }

            ++localCnt; // 這裡注意,傳送時候 localCnt 還沒有更新

// runtime 的資料如下:
val = {String[7]@10008} 
 0 = "老王"
 1 = "是"
 2 = "我們"
 3 = "團隊"
 4 = "裡"
 5 = "最胖"
 6 = "的"                    
                    }
        }
      }).coGroup(vocab)
      .where(3) // 上步處理中,f3是word
      .equalTo(1) // vocab.f1 是word
      .with(new CoGroupFunction <Tuple4 <Integer, Long, Integer, String>, Tuple3 <Integer, String, Word>,
        Tuple4 <Integer, Long, Integer, Integer>>() {
        @Override
        public void coGroup(Iterable <Tuple4 <Integer, Long, Integer, String>> first,
                  Iterable <Tuple3 <Integer, String, Word>> second,
                  Collector <Tuple4 <Integer, Long, Integer, Integer>> out) {
          for (Tuple3 <Integer, String, Word> row : second) {
            for (Tuple4 <Integer, Long, Integer, String> tuple : first) {
              out.collect(
                Tuple4.of(tuple.f0, tuple.f1, tuple.f2,
                  row.getField(0))); // 將單詞和詞典篩選合併, 返回 <taskId, 本分割槽句子數目, 本單詞在本句子中的idx,單詞在詞典的idx>
// runtime的變數是:
row = {Tuple3@10640}  // Tuple3<單詞在詞典的idx,單詞,單詞在詞典中對應的元素>
 f0 = {Integer@10650} 7
 f1 = "老黃"
 f2 = {Word2VecTrainBatchOp$Word@10652} 
                            
tuple = {Tuple4@10641} // (taskId, 本分割槽句子數目, 本單詞在本句子中的idx, 本單詞)
 f0 = {Integer@10642} 1
 f1 = {Long@10643} 0
 f2 = {Integer@10644} 0
 f3 = "老黃"                        
                        
                        }
          }
        }
      }).groupBy(0, 1) // 分組排序
      .reduceGroup(new GroupReduceFunction <Tuple4 <Integer, Long, Integer, Integer>, int[]>() {
        @Override
        public void reduce(Iterable <Tuple4 <Integer, Long, Integer, Integer>> values, Collector <int[]> out) {
          ArrayList <Tuple2 <Integer, Integer>> elements = new ArrayList <>();

          for (Tuple4 <Integer, Long, Integer, Integer> val : values) {
            // 得到 (本單詞在本句子中的idx, 本單詞在詞典的idx)
            elements.add(Tuple2.of(val.f2, val.f3));
          }
 
// runtime變數如下:
val = {Tuple4@10732} "(2,0,0,0)" //  <taskId, 本分割槽句子數目, 本單詞在本句子中的idx,單詞在詞典的idx>
 f0 = {Integer@10737} 2
 f1 = {Long@10738} 0
 f2 = {Integer@10733} 0
 f3 = {Integer@10733} 0  
    
elements = {ArrayList@10797}  size = 7
 0 = {Tuple2@10803} "(0,4)"
 1 = {Tuple2@10804} "(1,1)"
 2 = {Tuple2@10805} "(2,9)"
 3 = {Tuple2@10806} "(3,3)"
 4 = {Tuple2@10807} "(4,8)"
 5 = {Tuple2@10808} "(5,6)"
 6 = {Tuple2@10809} "(6,2)"                 

          Collections.sort(elements, new Comparator <Tuple2 <Integer, Integer>>() {
            @Override
            public int compare(Tuple2 <Integer, Integer> o1, Tuple2 <Integer, Integer> o2) {
              return o1.f0.compareTo(o2.f0);
            }
          });

          int[] ret = new int[elements.size()];

          for (int i = 0; i < elements.size(); ++i) {
            ret[i] = elements.get(i).f1; // 返回 "本單詞在詞典的idx"
          }

// runtime變數如下:                    
ret = {int[7]@10799} 
 0 = 4
 1 = 1
 2 = 9
 3 = 3
 4 = 8
 5 = 6
 6 = 2                    
          out.collect(ret);
        }
      });
}

這裡使用了 Flink coGroup 功能完成了雙流匹配合並功能。coGroup 和 Join 的區別是:

  • Join:Flink只輸出條件匹配的元素對 給 使用者;
  • coGroup :除了輸出匹配的元素對以外,也會輸出未能匹配的元素;

在 coGroup 的 CoGroupFunction 中,想輸出什麼形式的元素都行,完全看使用者的具體實現。

5.6 獲取精簡詞典

到了這一步,已經把每個句子都翻譯成了一個詞典idx的序列,比如:

原始輸入 : "老王 是 我們 團隊 裡 最胖 的"

編碼之後 : “4,1,9,3,8,6,2” , 這裡每個數字是 本句子中每個單詞在詞典中的序列號。

接下來Alink換了一條路,精簡詞典, 就是去掉單詞原始文字。

DataSet <Tuple2 <Integer, Word>> vocabWithoutWordStr = vocab
      .map(new UseVocabWithoutWordString());

原始詞典是 Tuple3<單詞在詞典的idx,單詞,單詞在詞典中對應的元素>

"(1,的,com.alibaba.alink.operator.batch.nlp.Word2VecTrainBatchOp$Word@13099fc)"

精簡之後的詞典是 Tuple2<單詞在詞典的idx,單詞在詞典中對應的元素>

"(1, com.alibaba.alink.operator.batch.nlp.Word2VecTrainBatchOp$Word@13099fc)"

程式碼如下:

private static class UseVocabWithoutWordString
    implements MapFunction <Tuple3 <Integer, String, Word>, Tuple2 <Integer, Word>> {
    @Override
    public Tuple2 <Integer, Word> map(Tuple3 <Integer, String, Word> value) throws Exception {
      return Tuple2.of(value.f0, value.f2); // 去掉單詞原始文字 f1
    }
}

// runtime變數如下:
value = {Tuple3@10692} "(1,的,com.alibaba.alink.operator.batch.nlp.Word2VecTrainBatchOp$Word@13099fc)"
 f0 = {Integer@10693} 1
  value = 1
 f1 = "的"
  value = {char[1]@10700} 
  hash = 0
 f2 = {Word2VecTrainBatchOp$Word@10694} 
  cnt = 2
  point = {int[3]@10698} 
   0 = 8
   1 = 7
   2 = 5
  code = {int[3]@10699} 
   0 = 1
   1 = 0
   2 = 1

5.7 初始化模型

用精簡後的詞典初始化模型,即隨機初始化所有的模型權重引數θ,所有的詞向量w

DataSet <Tuple2 <Integer, double[]>> initialModel = vocabWithoutWordStr
      .mapPartition(new initialModel(seed, vectorSize))
      .rebalance();

現在詞典是:Tuple2<每個單詞在詞典的idx,每個單詞在詞典中對應的元素>,這裡只用到了 idx。

最後初始化的模型是 :<每個單詞在詞典中的idx,隨機初始化的權重係數>。權重大小預設是 100。

具體程式碼是

private static class initialModel
    extends RichMapPartitionFunction <Tuple2 <Integer, Word>, Tuple2 <Integer, double[]>> {
    private final long seed;
    private final int vectorSize;
    Random random;

    public initialModel(long seed, int vectorSize) {
      this.seed = seed;
      this.vectorSize = vectorSize;
      random = new Random();
    }

    @Override
    public void open(Configuration parameters) throws Exception {
      random.setSeed(seed + getRuntimeContext().getIndexOfThisSubtask());
    }

    @Override
    public void mapPartition(Iterable <Tuple2 <Integer, Word>> values,
                 Collector <Tuple2 <Integer, double[]>> out) throws Exception {
      for (Tuple2 <Integer, Word> val : values) {
        double[] inBuf = new double[vectorSize];

        for (int i = 0; i < vectorSize; ++i) {
          inBuf[i] = random.nextFloat();
        }

        // 傳送 <每個單詞在詞典中的idx,隨機初始化的係數>
        out.collect(Tuple2.of(val.f0, inBuf));
      }
    }
}

5.8 計算迭代次數

現在計算迭代訓練的次數,就是 "訓練語料中所有單詞數目 / 100000L" 和 5 之間的最大值。

DataSet <Integer> syncNum = DataSetUtils
      .countElementsPerPartition(trainData)
      .sum(1)
      .map(new MapFunction <Tuple2 <Integer, Long>, Integer>() {
        @Override
        public Integer map(Tuple2 <Integer, Long> value) throws Exception {
          return Math.max((int) (value.f1 / 100000L), 5);
        }
      });

至此,完成了預處理節點:對輸入的處理,以及詞典、二叉樹的建立。下一步就是要開始訓練。

0xFF 參考

word2vec原理推導與程式碼分析

文字深度表示模型Word2Vec

word2vec原理(二) 基於Hierarchical Softmax的模型

word2vec原理(一) CBOW與Skip-Gram模型基礎

word2vec原理(三) 基於Negative Sampling的模型

word2vec概述

對Word2Vec的理解

自己動手寫word2vec (一):主要概念和流程

自己動手寫word2vec (二):統計詞頻

自己動手寫word2vec (三):構建Huffman樹

自己動手寫word2vec (四):CBOW和skip-gram模型

word2vec 中的數學原理詳解(一)目錄和前言

基於 Hierarchical Softmax 的模型

基於 Negative Sampling 的模型

機器學習演算法實現解析——word2vec原始碼解析

Word2Vec原始碼解析

word2vec原始碼思路和關鍵變數

Word2Vec原始碼最詳細解析(下)

word2vec原始碼思路和關鍵變數