Spark K-Means 演算法例子
阿新 • • 發佈:2019-02-07
k-means演算法是以空間的點距離為基準,隨機或者按照一定規則選擇幾個中心點資料,計算每個點到該幾個中心點的距離,按照距離值最近歸為一類的原則,把空間所有的點歸為初始化的幾個中心,稱之為中心簇。
然後,找到每個中心簇的中心,再次計算空間所有的點到新的中心點的舉例並歸類,以此不斷迭代,直到達到迭代次數或者點中心不再變化為止。
kmeans_data.txt中的資料
0.0 0.0 0.0
0.1 0.1 0.1
0.2 0.2 0.2
9.0 9.0 9.0
9.1 9.1 9.1
9.2 9.2 9.2
package spark;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
// $example on$
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.clustering.KMeans;
import org.apache.spark.mllib.clustering.KMeansModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
public class JavaKMeansExample {
public static void main(String[] args) {
Logger logger = Logger.getLogger(JavaKMeansExample.class);
// 設定日誌的等級 並關閉jetty容器的日誌
Logger.getLogger("org.apache.spark").setLevel(Level.WARN);
Logger.getLogger("org.apache.eclipse.jetty.server" ).setLevel(Level.OFF);
SparkConf conf = new SparkConf().setMaster("local[2]").setAppName("JavaKMeansExample");
JavaSparkContext jsc = new JavaSparkContext(conf);
String path = "F:/spark-2.1.0-bin-hadoop2.6/data/mllib/kmeans_data.txt";
JavaRDD<String> data = jsc.textFile(path);
JavaRDD<Vector> parsedData = data.map(s -> {
String[] sarray = s.split(" ");
double[] values = new double[sarray.length];
for (int i = 0; i < sarray.length; i++) {
values[i] = Double.parseDouble(sarray[i]);
}
return Vectors.dense(values);
});
parsedData.cache();
int numClusters = 2;
int numIterations = 20;
int runs = 10;
/**
* KMeans.train(RDD<Vector> data, int k, int maxIterations, int runs, String initializationMode, long seed) data 進行聚類的資料 k
* 初始的中心點個數 maxIterations 迭代次數
* runs 執行次數
* initializationMode 初始中心點的選擇方式, 目前支援隨機選 "random" or "k-means||"。預設是 K-means||
* seed 叢集初始化時的隨機種子。
*/
KMeansModel clusters = KMeans.train(parsedData.rdd(), numClusters, numIterations, runs);
// 輸出聚類的中心
System.out.println("Cluster centers:");
for (Vector center : clusters.clusterCenters()) {
System.out.println(" " + center);
}
// 本次聚類操作的收斂性,此值越低越好
double cost = clusters.computeCost(parsedData.rdd());
System.out.println("Cost: " + cost);
double WSSSE = clusters.computeCost(parsedData.rdd());
System.out.println("Within Set Sum of Squared Errors = " + WSSSE);
// 預測並輸出輸出每組資料對應的中心
parsedData.foreach(f -> {
System.out.print(f + "\n");
System.out.println(clusters.predict(f));
});
// 預測資料屬於哪個中心點
int centerIndex = clusters.predict(Vectors.dense(new double[] {3.6, 4.7, 7.1}));//中心點的索引
System.out.println("預測資料 (3.6, 4.7, 7.1)屬於中心[" + centerIndex + "]:" + clusters.clusterCenters()[centerIndex]);
centerIndex = clusters.predict(Vectors.dense(new double[] {1.1, 0.7, 0.3}));
System.out.println("預測資料 (1.1,0.7, 0.3)屬於中心[" + centerIndex + "]:" + clusters.clusterCenters()[centerIndex]);
jsc.stop();
}
}
...
參考