使用Java實現K-Means聚類演算法
阿新 • • 發佈:2018-11-22
第一次寫部落格,隨便寫寫。
關於K-Means介紹很多,還不清楚可以查一些相關資料。
個人對其實現步驟簡單總結為4步:
1.選定k值,並隨機產生k個起始質心點。
2.分別計算每個點與k個質心點之間的距離,將每個點歸入距離最近的質心點所屬的類別。
3.此時點集被劃分為k類,分別計算每一類中所有點的平均值,作為該類新的質心點。
4.重複步驟2、3,對所有點重新歸類,當所有分類的質心點都不再改變時,演算法即收斂。
下面貼程式碼。
1.入口類,讀取資料來源檔案,進行訓練後輸出結果。資料來源檔案和原始碼後面會補上。
package com.hyr.kmeans; import au.com.bytecode.opencsv.CSVReader; import java.io.FileReader; import java.io.FileWriter; import java.io.IOException; import java.util.ArrayList; import java.util.List; public class KmeansMain { public static void main(String[] args) throws IOException { // 讀取資料來源檔案 CSVReader reader = new CSVReader(new FileReader("src/main/resources/data.csv")); // 資料來源 FileWriter writer = new FileWriter("src/main/resources/out.csv"); List<String[]> myEntries = reader.readAll(); // 6.8, 12.6 // 轉換資料點集 List<Point> points = new ArrayList<Point>(); // 資料點集 for (String[] entry : myEntries) { points.add(new Point(Float.parseFloat(entry[0]), Float.parseFloat(entry[1]))); } int k = 6; // K值 int type = 1; KmeansModel model = Kmeans.run(points, k, type); writer.write("==================== K is " + model.getK() + " , Object Funcion Value is " + model.getOfv() + " , calc_distance_type is " + model.getCalc_distance_type() + " ====================\n"); int i = 0; for (Cluster cluster : model.getClusters()) { i++; writer.write("==================== classification " + i + " ====================\n"); for (Point point : cluster.getPoints()) { writer.write(point.toString() + "\n"); } writer.write("\n"); writer.write("centroid is " + cluster.getCentroid().toString()); writer.write("\n\n"); } writer.close(); } }
2.最終生成的模型類,也就是訓練好的結果,包含分類結果、K值、距離計算方式以及目標函數值(object function value)。
package com.hyr.kmeans; import java.util.ArrayList; import java.util.List; public class KmeansModel { private List<Cluster> clusters = new ArrayList<Cluster>(); private Double ofv; private int k; // k值 private int calc_distance_type; public KmeansModel(List<Cluster> clusters, Double ofv, int k, int calc_distance_type) { this.clusters = clusters; this.ofv = ofv; this.k = k; this.calc_distance_type = calc_distance_type; } public List<Cluster> getClusters() { return clusters; } public Double getOfv() { return ofv; } public int getK() { return k; } public int getCalc_distance_type() { return calc_distance_type; } }
3.資料集的點物件,包含點的各維座標(程式碼中只實作了x軸、y軸兩個維度)以及點之間的距離計算。透過type參數選擇距離公式,提供了幾種常用的距離公式(L1、Canberra、歐幾里得距離)。
package com.hyr.kmeans; public class Point { private Float x; // x 軸 private Float y; // y 軸 public Point(Float x, Float y) { this.x = x; this.y = y; } public Float getX() { return x; } public void setX(Float x) { this.x = x; } public Float getY() { return y; } public void setY(Float y) { this.y = y; } @Override public String toString() { return "Point{" + "x=" + x + ", y=" + y + '}'; } /** * 計算距離 * * @param centroid 質心點 * @param type * @return */ public Double calculateDistance(Point centroid, int type) { // TODO Double result = null; switch (type) { case 1: result = calcL1Distance(centroid); break; case 2: result = calcCanberraDistance(centroid); break; case 3: result = calcEuclidianDistance(centroid); break; } return result; } /* 計算距離公式 */ private Double calcL1Distance(Point centroid) { double res = 0; res = Math.abs(getX() - centroid.getX()) + Math.abs(getY() - centroid.getY()); return res / (double) 2; } private double calcEuclidianDistance(Point centroid) { return Math.sqrt(Math.pow((centroid.getX() - getX()), 2) + Math.pow((centroid.getY() - getY()), 2)); } private double calcCanberraDistance(Point centroid) { double res = 0; res = Math.abs(getX() - centroid.getX()) / (Math.abs(getX()) + Math.abs(centroid.getX())) + Math.abs(getY() - centroid.getY()) / (Math.abs(getY()) + Math.abs(centroid.getY())); return res / (double) 2; } @Override public boolean equals(Object obj) { Point other = (Point) obj; if (getX().equals(other.getX()) && getY().equals(other.getY())) { return true; } return false; } }
4.訓練後最終得到的分類,包含該分類的質心點、屬於該分類的點集,以及該分類是否已收斂。
package com.hyr.kmeans;
import java.util.ArrayList;
import java.util.List;
/**
 * One cluster of a K-Means run: its centroid, the points currently assigned to it, and a flag
 * the training loop sets when the centroid did not move.
 */
public class Cluster {

    private Point centroid;                              // center of this cluster
    private List<Point> points = new ArrayList<Point>(); // points assigned to this cluster
    private boolean converged = false;                   // set by the training loop

    /** Returns the live list of assigned points (not a copy). */
    public List<Point> getPoints() {
        return this.points;
    }

    /** Replaces the list of assigned points. */
    public void setPoints(List<Point> points) {
        this.points = points;
    }

    /** Removes every assigned point, keeping the same list instance. */
    public void initPoint() {
        this.points.clear();
    }

    public Point getCentroid() {
        return this.centroid;
    }

    public void setCentroid(Point centroid) {
        this.centroid = centroid;
    }

    public boolean isConvergence() {
        return this.converged;
    }

    public void setConvergence(boolean convergence) {
        this.converged = convergence;
    }

    /** Describes the cluster by its centroid; throws NPE if no centroid has been set. */
    @Override
    public String toString() {
        return this.centroid.toString();
    }
}
5.K-Means訓練類,按照上面所說的四個步驟反覆進行訓練。
package com.hyr.kmeans;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
public class Kmeans {

    /**
     * Safety cap on the number of assign/update passes. With L1 or Canberra distance the
     * nearest-centroid assignment and the mean-based centroid update do not minimize the same
     * objective, so centroids are not guaranteed to settle; the cap guarantees termination.
     */
    private static final int MAX_ITERATIONS = 1000;

    /**
     * Clusters {@code points} into {@code k} groups with the K-Means (Lloyd) iteration.
     *
     * @param points data set to cluster; must not be null or empty
     * @param k      number of clusters; must be positive
     * @param type   distance formula code passed to {@link Point#calculateDistance(Point, int)}
     * @return model holding the clusters, the objective function value (computed with the same
     *         distance formula), {@code k} and {@code type}
     * @throws IllegalArgumentException if {@code points} is null or empty, or {@code k <= 0}
     */
    public static KmeansModel run(List<Point> points, int k, int type) {
        if (points == null || points.isEmpty()) {
            // Without points the bounding box is infinite and every centroid would become NaN.
            throw new IllegalArgumentException("points must not be null or empty");
        }
        if (k <= 0) {
            throw new IllegalArgumentException("k must be positive, got " + k);
        }
        // Random initial centroids inside the bounding box of the data.
        List<Cluster> clusters = initCentroides(points, k);
        int iterations = 0;
        // Repeat until no centroid moved during the same pass (or the safety cap is reached).
        while (!checkConvergence(clusters) && iterations < MAX_ITERATIONS) {
            initClusters(clusters);                // 1. forget the previous assignment
            classifyPoint(points, clusters, type); // 2. assign each point to its nearest centroid
            recalcularCentroides(clusters);        // 3. move centroids, flag the ones that stayed
            iterations++;
        }
        Double ofv = calcularObjetiFuncionValue(clusters, type);
        return new KmeansModel(clusters, ofv, k, type);
    }

    /**
     * Creates {@code k} clusters whose centroids are drawn uniformly at random inside the
     * axis-aligned bounding box of {@code points}.
     *
     * @param points data set (non-empty)
     * @param k      number of clusters to create
     * @return the new clusters, each with a centroid and no points
     */
    private static List<Cluster> initCentroides(List<Point> points, Integer k) {
        List<Cluster> centroides = new ArrayList<Cluster>();
        // Bounding box of the data set: min/max of the x and y coordinates.
        float max_X = Float.NEGATIVE_INFINITY;
        float max_Y = Float.NEGATIVE_INFINITY;
        float min_X = Float.POSITIVE_INFINITY;
        float min_Y = Float.POSITIVE_INFINITY;
        for (Point point : points) {
            max_X = max_X < point.getX() ? point.getX() : max_X;
            max_Y = max_Y < point.getY() ? point.getY() : max_Y;
            min_X = min_X > point.getX() ? point.getX() : min_X;
            min_Y = min_Y > point.getY() ? point.getY() : min_Y;
        }
        System.out.println("min_X" + min_X + ",max_X:" + max_X + ",min_Y" + min_Y + ",max_Y" + max_Y);
        Random random = new Random();
        for (int i = 0; i < k; i++) {
            float x = random.nextFloat() * (max_X - min_X) + min_X;
            // Offset by min_Y (not min_X); otherwise the y range is shifted away from the data.
            float y = random.nextFloat() * (max_Y - min_Y) + min_Y;
            Cluster c = new Cluster();
            c.setCentroid(new Point(x, y));
            centroides.add(c);
        }
        return centroides;
    }

    /**
     * Moves every centroid to the mean of its assigned points and records, per cluster, whether
     * the centroid stayed where it was during this pass.
     *
     * @param clusters clusters with their current point assignment
     */
    private static void recalcularCentroides(List<Cluster> clusters) {
        for (Cluster c : clusters) {
            List<Point> members = c.getPoints();
            if (members.isEmpty()) {
                // An empty cluster keeps its centroid, so it did not move in this pass.
                c.setConvergence(true);
                continue;
            }
            float sum_x = 0f;
            float sum_y = 0f;
            for (Point point : members) {
                sum_x += point.getX();
                sum_y += point.getY();
            }
            Point nuevoCentroide = new Point(sum_x / members.size(), sum_y / members.size());
            // Re-evaluate the flag on every pass: a cluster that was stable earlier can move again
            // once neighbouring centroids take or hand over points. Only ever setting it to true
            // would let the loop stop before the assignment is stable.
            boolean unchanged = nuevoCentroide.equals(c.getCentroid());
            c.setConvergence(unchanged);
            if (!unchanged) {
                c.setCentroid(nuevoCentroide);
            }
        }
    }

    /**
     * Assigns every point to the cluster whose centroid is nearest under the given distance.
     * Ties keep the earliest cluster; if no distance is smaller than Double.MAX_VALUE (e.g. all
     * NaN) the point goes to the first cluster.
     *
     * @param points   data set
     * @param clusters clusters (their point lists are appended to)
     * @param type     distance formula code
     */
    private static void classifyPoint(List<Point> points, List<Cluster> clusters, int type) {
        for (Point point : points) {
            Cluster masCercano = clusters.get(0); // nearest cluster found so far
            double minDistancia = Double.MAX_VALUE;
            for (Cluster cluster : clusters) {
                double distancia = point.calculateDistance(cluster.getCentroid(), type);
                if (distancia < minDistancia) {
                    minDistancia = distancia;
                    masCercano = cluster;
                }
            }
            masCercano.getPoints().add(point);
        }
    }

    /** Clears the point assignment of every cluster. */
    private static void initClusters(List<Cluster> clusters) {
        for (Cluster cluster : clusters) {
            cluster.initPoint();
        }
    }

    /**
     * Returns {@code true} when every cluster is flagged as converged.
     *
     * @param clusters clusters to inspect
     * @return whether all clusters are converged
     */
    private static boolean checkConvergence(List<Cluster> clusters) {
        for (Cluster cluster : clusters) {
            if (!cluster.isConvergence()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Computes the objective function value: the sum of distances from every point to its
     * cluster's centroid, using the same distance formula that drove the clustering (it was
     * previously hard-coded to type 1 regardless of the type reported in the model).
     *
     * @param clusters trained clusters
     * @param type     distance formula code
     * @return the objective function value
     */
    private static Double calcularObjetiFuncionValue(List<Cluster> clusters, int type) {
        double ofv = 0d;
        for (Cluster cluster : clusters) {
            for (Point point : cluster.getPoints()) {
                ofv += point.calculateDistance(cluster.getCentroid(), type);
            }
        }
        return ofv;
    }
}
最終訓練結果:
==================== K is 6 , Object Funcion Value is 21.82857036590576 , calc_distance_type is 3 ====================
==================== classification 1 ====================
Point{x=3.5, y=12.5}
centroid is Point{x=3.5, y=12.5}
==================== classification 2 ====================
Point{x=6.8, y=12.6}
Point{x=7.8, y=12.2}
Point{x=8.2, y=11.1}
Point{x=9.6, y=11.1}
centroid is Point{x=8.1, y=11.75}
==================== classification 3 ====================
Point{x=4.4, y=6.5}
Point{x=4.8, y=1.1}
Point{x=5.3, y=6.4}
Point{x=6.6, y=7.7}
Point{x=8.2, y=4.5}
Point{x=8.4, y=6.9}
Point{x=9.0, y=3.4}
centroid is Point{x=6.671428, y=5.2142863}
==================== classification 4 ====================
Point{x=6.0, y=19.9}
Point{x=6.2, y=18.5}
Point{x=5.3, y=19.4}
Point{x=7.6, y=17.4}
centroid is Point{x=6.275, y=18.800001}
==================== classification 5 ====================
Point{x=0.8, y=9.8}
Point{x=1.2, y=11.6}
Point{x=2.8, y=9.6}
Point{x=3.8, y=9.9}
centroid is Point{x=2.15, y=10.225}
==================== classification 6 ====================
Point{x=6.1, y=14.3}
centroid is Point{x=6.1, y=14.3}
程式碼下載地址:
http://download.csdn.net/download/huangyueranbbc/10267041
github:
https://github.com/huangyueranbbc/KmeansDemo