1. 程式人生 > >Spark K-Means 演算法例子

Spark K-Means 演算法例子

k-means演算法是以空間的點距離為基準,隨機或者按照一定規則選擇幾個中心點資料,計算每個點到該幾個中心點的距離,按照距離值最近歸為一類的原則,把空間所有的點歸為初始化的幾個中心,稱之為中心簇。
然後,找到每個中心簇的中心,再次計算空間所有的點到新的中心點的舉例並歸類,以此不斷迭代,直到達到迭代次數或者點中心不再變化為止。

kmeans_data.txt中的資料

0.0 0.0 0.0
0.1 0.1 0.1
0.2 0.2 0.2
9.0 9.0 9.0
9.1 9.1 9.1
9.2 9.2 9.2

package spark;

import org.apache.log4j.Level;
import
org.apache.log4j.Logger; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; // $example on$ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.mllib.clustering.KMeans; import org.apache.spark.mllib.clustering.KMeansModel; import org.apache.spark.mllib.linalg.Vector; import
org.apache.spark.mllib.linalg.Vectors; public class JavaKMeansExample { public static void main(String[] args) { Logger logger = Logger.getLogger(JavaKMeansExample.class); // 設定日誌的等級 並關閉jetty容器的日誌 Logger.getLogger("org.apache.spark").setLevel(Level.WARN); Logger.getLogger("org.apache.eclipse.jetty.server"
).setLevel(Level.OFF); SparkConf conf = new SparkConf().setMaster("local[2]").setAppName("JavaKMeansExample"); JavaSparkContext jsc = new JavaSparkContext(conf); String path = "F:/spark-2.1.0-bin-hadoop2.6/data/mllib/kmeans_data.txt"; JavaRDD<String> data = jsc.textFile(path); JavaRDD<Vector> parsedData = data.map(s -> { String[] sarray = s.split(" "); double[] values = new double[sarray.length]; for (int i = 0; i < sarray.length; i++) { values[i] = Double.parseDouble(sarray[i]); } return Vectors.dense(values); }); parsedData.cache(); int numClusters = 2; int numIterations = 20; int runs = 10; /** * KMeans.train(RDD<Vector> data, int k, int maxIterations, int runs, String initializationMode, long seed) data 進行聚類的資料 k * 初始的中心點個數 maxIterations 迭代次數 * runs 執行次數 * initializationMode 初始中心點的選擇方式, 目前支援隨機選 "random" or "k-means||"。預設是 K-means|| * seed 叢集初始化時的隨機種子。 */ KMeansModel clusters = KMeans.train(parsedData.rdd(), numClusters, numIterations, runs); // 輸出聚類的中心 System.out.println("Cluster centers:"); for (Vector center : clusters.clusterCenters()) { System.out.println(" " + center); } // 本次聚類操作的收斂性,此值越低越好 double cost = clusters.computeCost(parsedData.rdd()); System.out.println("Cost: " + cost); double WSSSE = clusters.computeCost(parsedData.rdd()); System.out.println("Within Set Sum of Squared Errors = " + WSSSE); // 預測並輸出輸出每組資料對應的中心 parsedData.foreach(f -> { System.out.print(f + "\n"); System.out.println(clusters.predict(f)); }); // 預測資料屬於哪個中心點 int centerIndex = clusters.predict(Vectors.dense(new double[] {3.6, 4.7, 7.1}));//中心點的索引 System.out.println("預測資料 (3.6, 4.7, 7.1)屬於中心[" + centerIndex + "]:" + clusters.clusterCenters()[centerIndex]); centerIndex = clusters.predict(Vectors.dense(new double[] {1.1, 0.7, 0.3})); System.out.println("預測資料 (1.1,0.7, 0.3)屬於中心[" + centerIndex + "]:" + clusters.clusterCenters()[centerIndex]); jsc.stop(); } } ...

參考