基於spark實現kmeans的更新聚類中心的關鍵程式碼
阿新 • • 發佈:2019-02-18
/**
* 這個方法作用是返回被給樣本的所屬聚類中心的索引,及其到這個聚類中心的距離,
* 其中class VectorWithNorm(val vector: Vector, val norm: Double)這是VectorWithNorm的定義,vector為樣本構成的向量形式,norm是p-norm值
*/
private[mllib] def findClosest(
centers: TraversableOnce[VectorWithNorm],
point: VectorWithNorm): (Int, Double) = {
var bestDistance = Double.PositiveInfinity
var bestIndex = 0
var i = 0
centers.foreach { center =>
// Since `\|a - b\| \geq |\|a\| - \|b\||`, we can use this lower bound to avoid unnecessary
// distance computation.
var lowerBoundOfSqDist = center.norm - point.norm
lowerBoundOfSqDist = lowerBoundOfSqDist * lowerBoundOfSqDist
if (lowerBoundOfSqDist < bestDistance) {
val distance: Double = fastSquaredDistance(center, point)
if (distance < bestDistance) {
bestDistance = distance
bestIndex = i
}
}
i += 1
}
(bestIndex, bestDistance)
}
/**
* 下面是實現kmeans演算法程式碼
* instr引數解釋一下這個在mllib裡面預設設定為None,之所以存在這個引數是因為ML庫中呼叫到了,因此傳入這個引數其實是適應ML裡面pipline的思想 ,Instrumentation是一個封裝器,記錄某個訓練session期間有用的資訊,class Instrumentation[E <: Estimator[_]] ,同時預設的初始點選擇方法非隨機,而是使用kmean++的方法進行選擇。
* VectorWithNorm(val vector: Vector, val norm: Double)
*/
private def runAlgorithm(
data: RDD[VectorWithNorm],
instr: Option[Instrumentation[NewKMeans]]): KMeansModel = {
val sc = data.sparkContext
val initStartTime = System.nanoTime()
//初始化聚類中心
val centers = initialModel match {
case Some(kMeansCenters) =>
kMeansCenters.clusterCenters.map(new VectorWithNorm(_))
case None =>
if (initializationMode == KMeans.RANDOM) {
initRandom(data)
} else {
initKMeansParallel(data)
}
}
val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9
logInfo(f"Initialization with $initializationMode took $initTimeInSeconds%.3f seconds.")
var converged = false
var cost = 0.0
var iteration = 0
val iterationStartTime = System.nanoTime()
instr.foreach(_.logNumFeatures(centers.head.vector.size))
// Execute iterations of Lloyd's algorithm until converged 累加器使用 廣播聚類中心變數
while (iteration < maxIterations && !converged) {
val costAccum = sc.doubleAccumulator
val bcCenters = sc.broadcast(centers)
// 找到每個聚類中心擁有的樣本數量 dims為向量即樣本的特徵維度,此處採用了mapPartitions 會對每個分割槽中的元素進行相同操作每個分割槽包含多個樣本即下面程式碼中的points ,不同於map是對RDD中每個元素進行操作
val totalContribs = data.mapPartitions { points =>
//bcCenters.value得到Array[VectorWithNorm]即獲得廣播變數的值
val thisCenters = bcCenters.value
//VectorWithNorm(val vector: Vector, val norm: Double);thisCenters為Array[VectorWithNorm]
val dims = thisCenters.head.vector.size
val sums = Array.fill(thisCenters.length)(Vectors.zeros(dims))
//建立了一個數組存放每個聚類中心對應的樣本數目
val counts = Array.fill(thisCenters.length)(0L)
//在每個分割槽中對每個樣本進行下面操作,尋找其最優的所屬中心bestCenter,及其距離cost,然後對應的counts陣列對應的bestCenter索引位置+1,該聚類中心樣本數多1
points.foreach { point =>
val (bestCenter, cost) = KMeans.findClosest(thisCenters, point)
costAccum.add(cost)
val sum = sums(bestCenter)
axpy(1.0, point.vector, sum)
counts(bestCenter) += 1
}
counts.indices.filter(counts(_) > 0).map(j => (j, (sums(j), counts(j)))).iterator
}.reduceByKey { case ((sum1, count1), (sum2, count2)) =>
axpy(1.0, sum2, sum1) //sum1 += 1.0 * sum2
(sum1, count1 + count2)
}.collectAsMap()
//collectAsMap()後得到了一個totalContribs變數格式為[key,(vector,long)]的map 其中key為聚類中心索引 long為該聚類中心樣本數目,sum1 sum2均為向量,此處應該是通過reducebykey將叢集中相同key的進行合併最終統計出每個聚類中心索引key對應的鍵值對形式的記錄key,(vector,long)] 其中key為聚類中心索引,vector為該聚類中心所對應的所有樣本每一個維度的和組成的向量,long代表了該聚類中心樣本數目
bcCenters.destroy(blocking = false)
// Update the cluster centers and costs 每個聚類中心儲存了其樣本每一維度的和及其樣本個數 用來更新聚類中心 裡面的sum為聚類中心
converged = true
totalContribs.foreach { case (j, (sum, count)) =>
//scal函式作用 sum*= (1.0 / count )*sum sum是一個向量將向量的每一維度乘以樣本個數得到每一維度平均值,即新的聚類中心
scal(1.0 / count, sum)
val newCenter = new VectorWithNorm(sum)
if (converged && KMeans.fastSquaredDistance(newCenter, centers(j)) > epsilon * epsilon) {
converged = false
}
centers(j) = newCenter
}
cost = costAccum.value
iteration += 1
}
val iterationTimeInSeconds = (System.nanoTime() - iterationStartTime) / 1e9
logInfo(f"Iterations took $iterationTimeInSeconds%.3f seconds.")
if (iteration == maxIterations) {
logInfo(s"KMeans reached the max number of iterations: $maxIterations.")
} else {
logInfo(s"KMeans converged in $iteration iterations.")
}
logInfo(s"The cost is $cost.")
new KMeansModel(centers.map(_.vector))
}