K-means演算法解析及程式碼
阿新 • • 發佈:2018-12-09
上週看到K-means演算法,覺得挺有意思的,然後就分析了一下原理,又用JAVA實現了一下,水平有限,還請看到此部落格的各路大神, 如果看到有誤的地方,還請幫我糾正一下。
我給這個演算法的定義:根據某種規則,將相同的或者相近的物件,存放到一起。
基本原理:
1.定義幾個初始點當做基準點,
2.計算出當前的聚類,
3.根據新的聚類,確定下一個基準點,然後再次計算出新的聚類。
用到的數學基礎有曼哈頓距離(兩點橫、縱座標差的絕對值之和)、算術平均等。
下面貼一下程式碼
package com.dhc.jstestdemo.Model;

import android.util.Log;

import java.util.ArrayList;
import java.util.List;

/**
 * Small K-means demonstration on a fixed set of 2-D points.
 *
 * <p>Usage: call {@link #initData()} to load the eight sample points and the
 * three initial base points (centroids), then
 * {@link #getNextBasePointAndUpdateCluster()} to run the clustering, then
 * {@link #typeList()} to log the resulting clusters.
 *
 * <p>Distances are Manhattan distances; a new base point is the arithmetic
 * mean of the points currently assigned to its cluster.
 *
 * Created by 大漠dreamer on 2018/11/26.
 */
public class KMeans {

    /** Log tag used by {@link #typeList()}. */
    private static final String TAG = "cluster";

    /**
     * Safety cap on the number of assign/update rounds. Clustering normally
     * stops earlier, as soon as no base point moves.
     */
    private static final int MAX_ITERATIONS = 100;

    /** The original points, in insertion order. */
    List<Point> original = null;
    Point point1 = null;
    Point point2 = null;
    Point point3 = null;
    Point point4 = null;
    Point point5 = null;
    Point point6 = null;
    Point point7 = null;
    Point point8 = null;

    /** The three clusters produced by the most recent assignment round. */
    List<Point> list1 = null;
    List<Point> list2 = null;
    List<Point> list3 = null;

    /** Current base points (centroids) of clusters 1, 2 and 3. */
    Point basePointOne = null;
    Point basePointTwo = null;
    Point basePointThree = null;

    /** Number of assign/update rounds performed by the latest clustering run. */
    private int calculator = 0;

    /**
     * Creates an empty instance; call {@link #initData()} before clustering.
     */
    public KMeans() {
    }

    /**
     * Loads the sample data: eight points and three initial base points
     * (points 1, 4 and 7), and resets the clusters and the round counter.
     */
    public void initData() {
        original = new ArrayList<>();
        list1 = new ArrayList<>();
        list2 = new ArrayList<>();
        list3 = new ArrayList<>();
        calculator = 0;

        point1 = new Point(1.0, 2.0);
        original.add(point1);
        point2 = new Point(4.0, 3.0);
        original.add(point2);
        point3 = new Point(3.0, 5.0);
        original.add(point3);
        point4 = new Point(4.0, 9.0);
        original.add(point4);
        point5 = new Point(2.0, 10.0);
        original.add(point5);
        point6 = new Point(6.0, 5.0);
        original.add(point6);
        point7 = new Point(5.0, 2.0);
        original.add(point7);
        point8 = new Point(7.0, 1.0);
        original.add(point8);

        // Initial base points: points 1, 4 and 7.
        basePointOne = point1;
        basePointTwo = point4;
        basePointThree = point7;
    }

    /**
     * Adds {@code point} to the cluster whose base point is nearest by
     * Manhattan distance. Ties go to the lower-numbered cluster.
     *
     * @param point      the point to classify
     * @param pointBase1 base point of cluster 1
     * @param pointBase2 base point of cluster 2
     * @param pointBase3 base point of cluster 3
     */
    private void setPointToCluster(Point point, Point pointBase1, Point pointBase2, Point pointBase3) {
        double toFirst = manhattanDistance(point, pointBase1);
        double toSecond = manhattanDistance(point, pointBase2);
        double toThird = manhattanDistance(point, pointBase3);

        if (toFirst <= toSecond && toFirst <= toThird) {
            list1.add(point);
        } else if (toSecond <= toThird) {
            list2.add(point);
        } else {
            list3.add(point);
        }
    }

    /**
     * Runs K-means until it converges: repeatedly assigns every point to its
     * nearest base point and recomputes each base point as the mean of its
     * cluster, stopping once no base point moves (or after
     * {@code MAX_ITERATIONS} rounds as a safety net).
     *
     * <p>When this method returns, {@code list1}..{@code list3} hold the
     * assignment that matches the final base points.
     *
     * @throws IllegalStateException if {@link #initData()} has not been called
     */
    public void getNextBasePointAndUpdateCluster() {
        if (original == null || list1 == null || list2 == null || list3 == null) {
            throw new IllegalStateException("initData() must be called before clustering");
        }

        calculator = 0;
        boolean moved = true;
        while (moved && calculator < MAX_ITERATIONS) {
            calculator++;

            // Start every round from empty clusters.
            list1.clear();
            list2.clear();
            list3.clear();

            for (Point point : original) {
                setPointToCluster(point, basePointOne, basePointTwo, basePointThree);
            }

            Point nextOne = centroidOf(list1, basePointOne);
            Point nextTwo = centroidOf(list2, basePointTwo);
            Point nextThree = centroidOf(list3, basePointThree);

            // Converged when every base point stays exactly where it was.
            moved = !samePosition(nextOne, basePointOne)
                    || !samePosition(nextTwo, basePointTwo)
                    || !samePosition(nextThree, basePointThree);

            basePointOne = nextOne;
            basePointTwo = nextTwo;
            basePointThree = nextThree;
        }
    }

    /**
     * Computes the arithmetic mean of a cluster.
     *
     * @param cluster  points assigned to the cluster
     * @param fallback base point to keep when the cluster is empty; this
     *                 avoids a 0/0 division that would yield NaN coordinates
     * @return a new point at the mean position, or {@code fallback} if the
     *         cluster has no points
     */
    private Point centroidOf(List<Point> cluster, Point fallback) {
        if (cluster.isEmpty()) {
            return fallback;
        }
        double sumX = 0.0;
        double sumY = 0.0;
        for (Point point : cluster) {
            sumX += point.getX();
            sumY += point.getY();
        }
        return new Point(sumX / cluster.size(), sumY / cluster.size());
    }

    /**
     * Returns whether two points have exactly the same coordinates.
     */
    private static boolean samePosition(Point a, Point b) {
        return Double.compare(a.getX(), b.getX()) == 0
                && Double.compare(a.getY(), b.getY()) == 0;
    }

    /**
     * Manhattan distance: |x2 - x1| + |y2 - y1|.
     */
    private static double manhattanDistance(Point pointOne, Point pointTwo) {
        return Math.abs(pointTwo.getX() - pointOne.getX())
                + Math.abs(pointTwo.getY() - pointOne.getY());
    }

    /**
     * Logs every point of clusters 1, 2 and 3 (in that order) together with
     * its 1-based position in the original list.
     */
    public void typeList() {
        logCluster(1, list1);
        logCluster(2, list2);
        logCluster(3, list3);
    }

    /**
     * Logs the points of one cluster; does nothing if the cluster has not
     * been created yet.
     *
     * @param clusterNumber number shown in the log message (1, 2 or 3)
     * @param cluster       the points of that cluster
     */
    private void logCluster(int clusterNumber, List<Point> cluster) {
        if (cluster == null) {
            return;
        }
        for (Point point : cluster) {
            int index = original.indexOf(point);
            Log.d(TAG, "我來自聚類" + clusterNumber + "----橫座標為:" + point.getX() + "縱座標為:" + point.getY() + "位於原始集合裡面的:" + (index + 1) + "位置");
        }
    }

    /**
     * A 2-D coordinate.
     */
    class Point {
        Double x;
        Double y;

        public Point(Double x, Double y) {
            this.x = x;
            this.y = y;
        }

        public Double getX() {
            return x;
        }

        public void setX(Double x) {
            this.x = x;
        }

        public Double getY() {
            return y;
        }

        public void setY(Double y) {
            this.y = y;
        }
    }
}
測試方法
/**
 * Runs the K-means demo end to end: loads the sample data, performs the
 * clustering, and logs the resulting clusters.
 */
private void clusterCal() {
    KMeans clusterer = new KMeans();
    clusterer.initData();
    clusterer.getNextBasePointAndUpdateCluster();
    clusterer.typeList();
}
執行結果:這裡需要糾正一個錯誤——聚類運算的次數不應該由自己事先定義;當計算到一定程度時,聚類基準點會不再發生變化,此時才代表計算完成。因此需要在程式中判斷基準點是否還在發生變化,並以此作為停止條件。