Spark MLlib原始碼分析—Word2Vec原始碼詳解
阿新 • • 發佈:2019-01-05
以下程式碼是我依據SparkMLlib(版本1.6)中Word2Vec原始碼改寫而來,基本算是照搬。此版Word2Vec是基於Hierarchical Softmax的Skip-gram模型的實現。
在閱讀原始碼之前,博主建議讀者先看一下《Word2Vec 中的數學原理詳解》,或者看本人根據這篇文件整理的摘要總結:
http://blog.csdn.net/liuyuemaicha/article/details/52611219
P.S. 程式碼註解得很詳細,閱讀程式碼請從類 CWord2Vec 的 fit 函式開始。
import java.nio.ByteBuffer
import java.util.{Random => JavaRandom}
import com.github.fommil.netlib.BLAS.{getInstance => blas}
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import scala.collection.mutable
import scala.util.hashing.MurmurHash3
/**
* Entry in the vocabulary: one distinct word plus the Huffman-tree data that
* createBinaryTree fills in for it.
*
* point and code are allocated with MAX_CODE_LENGTH slots by learnVocab; only the
* first codeLen entries are read during training.
*/
private case class VocabWord(
var word: String, // the token itself
var cn: Int, // number of occurrences of the token in the corpus
var point: Array[Int], // inner-node indices on the path from the root to this word's leaf
var code: Array[Int], // Huffman code bits (0/1), one per step of that path
var codeLen: Int // number of inner nodes on the path, i.e. the length of the Huffman code
)
class CWord2Vec extends Serializable{
// NOTE(review): this field is not referenced by the methods visible in this class
// (fit creates its own local `random`) -- confirm before removing.
private val random = new JavaRandom()
// Seed for vector initialisation and for the per-partition random generators in fit.
private var seed = new JavaRandom().nextLong()
private var vectorSize = 100 // dimension of each word vector
private var learningRate = 0.025 // initial learning rate (alpha)
private var numPartitions = 1 // number of partitions the indexed sentences are repartitioned into
private var numIterations = 60 // number of passes over the corpus
private var minCount = 5 // minimum frequency: words occurring fewer times are dropped from the vocabulary
private var maxSentenceLength = 1000 // sentences are split into chunks of at most this many words
private val EXP_TABLE_SIZE = 1000 // number of entries in the precomputed sigmoid table
private val MAX_EXP = 6 // the sigmoid table covers inputs in [-MAX_EXP, MAX_EXP)
private val MAX_CODE_LENGTH = 40 // capacity of the per-word code/point arrays
/** context words from [-window, window] */
private var window = 5
private var trainWordsCount = 0L // total occurrences of all vocabulary words (set by learnVocab)
private var vocabSize = 0 // number of distinct words kept in the vocabulary
private var vocab: Array[VocabWord] = null // vocabulary sorted by descending count (set by learnVocab)
private var vocabHash = mutable.HashMap.empty[String, Int] // word -> index into vocab
/**
 * Builds the vocabulary from the corpus.
 *
 * Counts every word in `dataset`, drops words that occur fewer than `minCount` times,
 * and stores the survivors in `vocab` sorted by descending count. Also rebuilds
 * `vocabHash` (word -> index into `vocab`) and `trainWordsCount` (sum of all kept counts).
 *
 * The derived state is reset before it is rebuilt, so calling `fit` more than once on
 * the same instance does not leave stale word indices behind or inflate `trainWordsCount`.
 *
 * @param dataset corpus; each element is one sentence given as a sequence of words
 * @throws IllegalArgumentException if no word reaches `minCount`
 */
private def learnVocab[S <: Iterable[String]](dataset: RDD[S]): Unit = {
  val words = dataset.flatMap(x => x)
  vocab = words.map(w => (w, 1))
    .reduceByKey(_ + _) // occurrences per word
    .filter(_._2 >= minCount) // drop words rarer than minCount
    .map(x => VocabWord(
      x._1,
      x._2,
      new Array[Int](MAX_CODE_LENGTH),
      new Array[Int](MAX_CODE_LENGTH),
      0))
    .collect()
    .sortWith((a, b) => a.cn > b.cn) // most frequent first
  vocabSize = vocab.length
  require(vocabSize > 0, "The vocabulary size should be > 0. You may need to check " +
    "the setting of minCount, which could be large enough to remove all your words in sentences.")

  // Start from a clean slate: a fresh map (rather than clear()) so that a map handed out
  // by an earlier run is never mutated, and a zeroed total.
  vocabHash = mutable.HashMap.empty[String, Int]
  trainWordsCount = 0L
  var a = 0
  while (a < vocabSize) {
    vocabHash += vocab(a).word -> a // index lookup for each vocabulary word
    trainWordsCount += vocab(a).cn // total number of training words
    a += 1
  }
  //logInfo(s"vocabSize = $vocabSize, trainWordsCount = $trainWordsCount")
}
/**
 * Builds the Huffman tree over the vocabulary and stores, for every word, its Huffman
 * code (`code`), the inner nodes on its root-to-leaf path (`point`) and the code length
 * (`codeLen`) in the corresponding `vocab` entry.
 *
 * Node numbering: nodes `0 until vocabSize` are the leaves (one per word, same order as
 * `vocab`); nodes `vocabSize until 2 * vocabSize - 1` are inner nodes in creation order;
 * node `2 * vocabSize - 2` is the root. Inner nodes are stored in `point` relative to
 * `vocabSize`, so they can index the inner-node weight array directly.
 *
 * Relies on `vocab` being sorted by descending count (as produced by learnVocab), so that
 * walking the leaves from the last index downwards visits them in ascending count order.
 */
private def createBinaryTree(): Unit = {
  val count = new Array[Long](vocabSize * 2 + 1) // weight of every node
  val binary = new Array[Int](vocabSize * 2 + 1) // code bit of every node relative to its parent
  val parentNode = new Array[Int](vocabSize * 2 + 1) // parent index of every node
  val code = new Array[Int](MAX_CODE_LENGTH) // scratch: code bits, leaf-to-root order
  val point = new Array[Int](MAX_CODE_LENGTH) // scratch: visited nodes, leaf-to-root order

  var a = 0
  while (a < vocabSize) {
    count(a) = vocab(a).cn // leaves are weighted by word frequency
    a += 1
  }
  while (a < 2 * vocabSize) {
    // Not-yet-created inner nodes must never win a comparison against a real node.
    // Word counts are Ints, so Long.MaxValue is strictly larger than any leaf weight
    // (the former 1e9 sentinel was not: a count >= 1e9 selected an uninitialised node).
    count(a) = Long.MaxValue
    a += 1
  }

  var pos1 = vocabSize - 1 // next unconsumed leaf, moving towards index 0
  var pos2 = vocabSize // next unconsumed inner node, moving upwards

  // Consumes and returns the lighter of the next leaf and the next inner node;
  // on a tie, or when the leaves are exhausted, the inner node is taken.
  def takeLightest(): Int =
    if (pos1 >= 0 && count(pos1) < count(pos2)) {
      pos1 -= 1
      pos1 + 1
    } else {
      pos2 += 1
      pos2 - 1
    }

  a = 0
  while (a < vocabSize - 1) { // each iteration merges two nodes into inner node vocabSize + a
    val min1i = takeLightest()
    val min2i = takeLightest()
    count(vocabSize + a) = count(min1i) + count(min2i)
    parentNode(min1i) = vocabSize + a
    parentNode(min2i) = vocabSize + a
    binary(min2i) = 1 // the second node taken gets bit 1, the first keeps bit 0
    a += 1
  }

  // Now assign binary code to each vocabulary word
  a = 0
  while (a < vocabSize) {
    var b = a
    var i = 0
    while (b != vocabSize * 2 - 2) { // climb from the leaf until the root is reached
      code(i) = binary(b)
      point(i) = b
      i += 1
      b = parentNode(b)
    }
    vocab(a).codeLen = i
    vocab(a).point(0) = vocabSize - 2 // every path starts at the root (relative index)
    b = 0
    while (b < i) {
      // Reverse into root-to-leaf order. point(i) receives the leaf itself (a negative
      // relative index); it lies beyond codeLen and is not read during training.
      vocab(a).code(i - b - 1) = code(b)
      vocab(a).point(i - b) = point(b) - vocabSize
      b += 1
    }
    a += 1
  }
}
/**
 * Precomputes a lookup table for the logistic sigmoid.
 *
 * Entry `i` holds sigmoid(x) = e^x / (e^x + 1) for
 * x = (2 * i / EXP_TABLE_SIZE - 1) * MAX_EXP, i.e. EXP_TABLE_SIZE evenly spaced
 * points starting at -MAX_EXP and stopping just short of MAX_EXP.
 *
 * @return the table, EXP_TABLE_SIZE entries long
 */
private def createExpTable(): Array[Float] =
  Array.tabulate(EXP_TABLE_SIZE) { slot =>
    val e = math.exp((2.0 * slot / EXP_TABLE_SIZE - 1.0) * MAX_EXP)
    (e / (e + 1.0)).toFloat
  }
/**
 * Trains a skip-gram model with hierarchical softmax on `dataset` and returns the
 * learned word vectors.
 *
 * Steps: build the vocabulary and the Huffman tree, convert every sentence to arrays of
 * vocabulary indices, then run `numIterations` passes. In each pass the current weights
 * are broadcast, every partition updates its copy while walking its sentences, the rows
 * that were touched are summed across partitions and written back to the global arrays.
 *
 * @param dataset corpus; each element is one sentence given as a sequence of words
 * @return a model holding the word -> index map and the flattened word vectors (syn0)
 * @throws RuntimeException if vocabSize * vectorSize does not fit below Int.MaxValue
 */
def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = {
learnVocab(dataset) // build the vocabulary (vocab, vocabHash, trainWordsCount)
createBinaryTree() // build the Huffman tree and fill in code/point/codeLen for every word
val sc = dataset.context
val expTable = sc.broadcast(createExpTable())
val bcVocab = sc.broadcast(vocab)
val bcVocabHash = sc.broadcast(vocabHash)
val sentences: RDD[Array[Int]] = dataset.mapPartitions { sentenceIter =>
// Each sentence will map to 0 or more Array[Int]
sentenceIter.flatMap { sentence =>
val wordIndexes = sentence.flatMap(bcVocabHash.value.get)// word -> vocabulary index; words not in the vocabulary are dropped
wordIndexes.grouped(maxSentenceLength).map(_.toArray) // split into chunks of at most maxSentenceLength indices
}
}
val newSentences = sentences.repartition(numPartitions).cache()
val initRandom = new XORShiftRandom(seed)
if (vocabSize.toLong * vectorSize >= Int.MaxValue) {
throw new RuntimeException("vocabSize.toLong * vectorSize >= Int.MaxValue, " +
"Int.MaxValue: " + Int.MaxValue)
}
// Word vectors (one row of vectorSize floats per vocabulary word), randomly initialised
// in [-0.5 / vectorSize, 0.5 / vectorSize).
val syn0Global = Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
// Inner-node parameter vectors of the Huffman tree, initialised to zero.
val syn1Global = new Array[Float](vocabSize * vectorSize)
var alpha = learningRate // current learning rate
for (k <- 1 to numIterations){ // one full pass over the corpus per iteration
val bcSyn0Global = sc.broadcast(syn0Global)
val bcSyn1Global = sc.broadcast(syn1Global)
// Train on every partition: idx is the partition index, iter iterates over the
// index-encoded sentences of that partition.
val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
// Generator seeded from the global seed, the partition index and the iteration number.
val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
// Per-row update counters; only rows with a non-zero counter are emitted below.
val syn0Modify = new Array[Int](vocabSize)
val syn1Modify = new Array[Int](vocabSize)
// The accumulator is (syn0, syn1, lastWordCount, wordCount). syn0/syn1 are the arrays
// obtained from the broadcasts; they are updated in place and passed along unchanged.
val model = iter.foldLeft((bcSyn0Global.value, bcSyn1Global.value, 0L, 0L)) {
case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
var lwc = lastWordCount
var wc = wordCount
// Recompute the learning rate after every 10000+ processed words: linear decay with
// the share of the corpus seen so far, floored at learningRate * 0.0001.
// NOTE(review): alpha is declared outside this closure and reassigned inside it; how
// that assignment propagates between tasks depends on Spark's closure handling -- confirm.
if (wordCount - lastWordCount > 10000) {
lwc = wordCount
// TODO: discount by iteration?
alpha =
learningRate * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1))
if (alpha < learningRate * 0.0001) alpha = learningRate * 0.0001
//logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
}
wc += sentence.length
var pos = 0
while (pos < sentence.length) {
val word = sentence(pos) // centre word: vocabulary index of the pos-th word of the sentence
// b randomly shrinks the window: the loop below visits offsets b - window .. window - b
// around pos, so at most 2 * window context words are used.
val b = random.nextInt(window)
// Train Skip-gram
var a = b
while (a < window * 2 + 1 - b) {// iterate over the context positions around pos
if (a != window) { // a == window is the centre word itself
val c = pos - window + a // absolute position of the context word in the sentence
if (c >= 0 && c < sentence.length) {
val lastWord = sentence(c) // context word; its vector (row l1 of syn0) is the one being trained
val l1 = lastWord * vectorSize
val neu1e = new Array[Float](vectorSize) // accumulated update for the context word's vector
// Hierarchical softmax
var d = 0
// Walk the Huffman path of the centre word from the root; it has codeLen inner nodes.
while (d < bcVocab.value(word).codeLen) {
val inner = bcVocab.value(word).point(d) // inner node reached at step d
val l2 = inner * vectorSize
// Propagate hidden -> output
var f = blas.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1)// dot product of the word vector and the inner-node vector
// Dot products outside (-MAX_EXP, MAX_EXP) are skipped: the table has no entry for them.
if (f > -MAX_EXP && f < MAX_EXP) {
val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
f = expTable.value(ind)
// g = (1 - code bit - sigmoid) * alpha
val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat
blas.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1) // neu1e += g * syn1[inner]
blas.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1) // syn1[inner] += g * syn0[lastWord]
syn1Modify(inner) += 1
}
d += 1
}
blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1) // syn0[lastWord] += neu1e
syn0Modify(lastWord) += 1
}
}
a += 1
}
pos += 1
}
(syn0, syn1, lwc, wc)
}
val syn0Local = model._1 // word vectors after this partition's updates
val syn1Local = model._2 // inner-node vectors after this partition's updates
// Only output modified vectors.
// Keys 0 until vocabSize address syn0 rows; keys vocabSize until 2 * vocabSize address syn1 rows.
Iterator.tabulate(vocabSize) { index =>
if (syn0Modify(index) > 0) {
Some((index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize)))
} else {
None
}
}.flatten ++ Iterator.tabulate(vocabSize) { index =>
if (syn1Modify(index) > 0) {
Some((index + vocabSize, syn1Local.slice(index * vectorSize, (index + 1) * vectorSize)))
} else {
None
}
}.flatten
}
// Combine the partitions' results: rows with the same key are added element-wise.
val synAgg = partial.reduceByKey { case (v1, v2) =>
blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1) // v1 += v2
v1
}.collect()
// Write the aggregated rows back into the global arrays for the next iteration.
var i = 0
while (i < synAgg.length) {
val index = synAgg(i)._1
if (index < vocabSize) {
Array.copy(synAgg(i)._2, 0, syn0Global, index * vectorSize, vectorSize)
} else {
Array.copy(synAgg(i)._2, 0, syn1Global, (index - vocabSize) * vectorSize, vectorSize)
}
i += 1
}
bcSyn0Global.unpersist(false)
bcSyn1Global.unpersist(false)
}
newSentences.unpersist()
expTable.unpersist()
bcVocab.unpersist()
bcVocabHash.unpersist()
val wordArray = vocab.map(_.word)
new Word2VecModel(wordArray.zipWithIndex.toMap, syn0Global)
}
}
/**
 * Trained word2vec model.
 *
 * @param wordIndex   maps each vocabulary word to its row index
 * @param wordVectors all word vectors flattened row by row; row `i` occupies
 *                    `i * vectorSize until (i + 1) * vectorSize`, where
 *                    `vectorSize = wordVectors.length / wordIndex.size`
 */
class Word2VecModel (
    val wordIndex: Map[String, Int],
    val wordVectors: Array[Float]) extends Serializable {

  private val numWords = wordIndex.size
  private val vectorSize = wordVectors.length / numWords

  /** Vocabulary words ordered by row index, so `wordList(i)` owns row `i`. */
  private val wordList: Array[String] =
    wordIndex.toSeq.sortBy(_._2).map(_._1).toArray

  /** Euclidean norm of every word vector, used to turn dot products into cosines. */
  private val wordVecNorms: Array[Double] =
    Array.tabulate(numWords) { i =>
      val vec = wordVectors.slice(i * vectorSize, i * vectorSize + vectorSize)
      blas.snrm2(vectorSize, vec, 1).toDouble
    }

  /**
   * Returns the vector of `word`.
   *
   * @throws IllegalStateException if `word` is not in the vocabulary
   */
  def transform(word: String): Vector = {
    wordIndex.get(word) match {
      case Some(ind) =>
        val vec = wordVectors.slice(ind * vectorSize, ind * vectorSize + vectorSize)
        Vectors.dense(vec.map(_.toDouble))
      case None =>
        throw new IllegalStateException(s"$word not in vocabulary")
    }
  }

  /**
   * Finds the words most similar to `word`, excluding `word` itself.
   *
   * @param word query word; must be in the vocabulary
   * @param num  maximum number of synonyms to return, must be > 0
   * @return up to `num` (word, cosine similarity) pairs, most similar first
   * @throws IllegalStateException    if `word` is not in the vocabulary
   * @throws IllegalArgumentException if `num` is not positive
   */
  def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
    val vector = transform(word)
    require(num > 0, "Number of similar words should > 0")
    // Fetch one extra candidate and remove the query word by name. Dropping the
    // top-ranked entry instead would discard a real neighbour whenever the query word
    // is not ranked first (e.g. a zero-norm vector or an exact tie).
    rankBySimilarity(vector, num + 1).filter(_._1 != word).take(num)
  }

  /**
   * Finds the words whose vectors are most similar to `vector`.
   *
   * The query is an arbitrary vector, so nothing is excluded: the closest word is the
   * first element of the result.
   *
   * @param vector query vector; assumed to have `vectorSize` components
   * @param num    maximum number of words to return, must be > 0
   * @return up to `num` (word, cosine similarity) pairs, most similar first
   * @throws IllegalArgumentException if `num` is not positive
   */
  def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
    require(num > 0, "Number of similar words should > 0")
    rankBySimilarity(vector, num)
  }

  /** Returns the `limit` vocabulary words with the highest cosine similarity to `vector`. */
  private def rankBySimilarity(vector: Vector, limit: Int): Array[(String, Double)] = {
    // TODO: optimize top-k
    val fVector = vector.toArray.map(_.toFloat)
    val cosineVec = Array.fill[Float](numWords)(0)
    val alpha: Float = 1
    val beta: Float = 0
    // Normalize input vector before blas.sgemv to avoid Inf value
    val vecNorm = blas.snrm2(vectorSize, fVector, 1)
    if (vecNorm != 0.0f) {
      blas.sscal(vectorSize, 1 / vecNorm, fVector, 0, 1)
    }
    // cosineVec = wordVectors^T * fVector: one dot product per vocabulary word.
    blas.sgemv(
      "T", vectorSize, numWords, alpha, wordVectors, vectorSize, fVector, 1, beta, cosineVec, 1)
    val cosVec = cosineVec.map(_.toDouble)
    var ind = 0
    while (ind < numWords) {
      val norm = wordVecNorms(ind)
      // A zero-norm word vector has no direction; give it similarity 0 instead of NaN.
      if (norm == 0.0) {
        cosVec(ind) = 0.0
      } else {
        cosVec(ind) /= norm
      }
      ind += 1
    }
    wordList.zip(cosVec)
      .toSeq
      .sortBy(-_._2)
      .take(limit)
      .toArray
  }
}
/**
 * Pseudo-random generator based on XORShift, plugged into java.util.Random by overriding
 * `next`. The initial state is a hash of `init` rather than `init` itself.
 *
 * @param init seed from which the initial state is derived
 */
private class XORShiftRandom(init: Long) extends JavaRandom(init) {
  private var seed = hashSeed(init)

  /**
   * Scrambles `value` into a 64-bit state with two rounds of MurmurHash3.
   *
   * The buffer is sized with java.lang.Long.SIZE, which is the width of a long in bits
   * (64), so 64 bytes are hashed: the 8 bytes of `value` followed by zeros. This is kept
   * as is because hashing a different number of bytes would change the generated stream.
   */
  private def hashSeed(value: Long): Long = {
    val buffer = ByteBuffer.allocate(java.lang.Long.SIZE)
    buffer.putLong(value)
    val raw = buffer.array()
    val low = MurmurHash3.bytesHash(raw)
    val high = MurmurHash3.bytesHash(raw, low)
    (high.toLong << 32) | (low.toLong & 0xFFFFFFFFL)
  }

  // Only `next` needs overriding: nextInt, nextDouble, nextGaussian, nextLong, etc.
  // are all implemented on top of it.
  override protected def next(bits: Int): Int = {
    val step1 = seed ^ (seed << 21)
    val step2 = step1 ^ (step1 >>> 35)
    val step3 = step2 ^ (step2 << 4)
    seed = step3
    // Keep only the lowest `bits` bits of the new state.
    val mask = (1L << bits) - 1
    (step3 & mask).toInt
  }
}