Spark下的KMeans演算法
阿新 • 發佈:2019-01-31
import breeze.linalg.{DenseVector, squaredDistance}
import org.apache.spark.{SparkConf, SparkContext}

/**
 * Simple K-Means clustering on Spark, using breeze dense vectors for the
 * per-point arithmetic.
 *
 * Usage: SimpleKMeans &lt;file&gt; &lt;k&gt; &lt;convergeDist&gt;
 *   - file:         whitespace-separated numeric vectors, one point per line
 *   - k:            number of clusters
 *   - convergeDist: stop when the total squared movement of all centers in
 *                   one iteration drops to this value or below
 *
 * Originally created by fhqplzj on 2016-07-16 at 12:21 PM.
 */
object SimpleKMeans {
  def main(args: Array[String]): Unit = {
    if (args.length < 3) {
      System.err.println("Usage: SimpleKMeans <file> <k> <convergeDist>")
      System.exit(1)
    }
    val conf = new SparkConf().setMaster("local[*]").setAppName("SimpleKMeans")
    val sc = new SparkContext(conf)
    try {
      // Parse each line into a DenseVector[Double]; cached because the RDD is
      // re-scanned once per K-Means iteration.
      val data = sc.textFile(args(0))
        .map(_.split("\\s+").map(_.toDouble))
        .map(DenseVector(_))
        .cache()
      val k = args(1).toInt
      val convergeDist = args(2).toDouble

      // Initial centers: k points sampled without replacement.
      // NOTE: takeSample may return fewer than k points if the input is small.
      val kPoints = data.takeSample(withReplacement = false, k, System.nanoTime())

      var tempDist = Double.PositiveInfinity
      while (tempDist > convergeDist) {
        // Assign every point to its nearest current center, then compute the
        // per-cluster mean as (vector sum, count) reduced by cluster index.
        val newPoints = data.map { p =>
          val minIdx = kPoints.map(squaredDistance(_, p)).zipWithIndex.minBy(_._1)._2
          (minIdx, (p, 1))
        }.reduceByKey { case ((sum1, count1), (sum2, count2)) =>
          (sum1 + sum2, count1 + count2)
        }.mapValues { case (sum, count) =>
          // Pure element-wise division. The original used the in-place `/=`
          // operator and depended on the mutating op returning the mutated
          // vector; plain `/` produces a fresh vector and is unambiguous.
          sum / count.toDouble
        }.collect()

        // Total squared movement of the centers in this iteration.
        tempDist = newPoints.map { case (idx, center) =>
          squaredDistance(kPoints(idx), center)
        }.sum

        // Update centers in place. A cluster that received no points is absent
        // from newPoints and simply keeps its previous center.
        newPoints.foreach { case (idx, center) => kPoints(idx) = center }
      }

      println("Final centers:")
      kPoints.foreach(println)
    } finally {
      // Always release the Spark context, even if parsing or the job fails.
      sc.stop()
    }
  }
}