1. 程式人生 > >K-means演算法解析及程式碼

K-means演算法解析及程式碼

上週看到K-means演算法,覺得挺有意思的,於是分析了一下原理,又用JAVA實現了一遍。水平有限,如果各路大神在此部落格中發現有誤的地方,還請幫我糾正一下。

我給這個演算法的定義:根據某種規則,將相同的或者相近的物件,存放到一起。

基本原理:

1.定義幾個初始點當做基準點,

2.計算出當前的聚類,

3.根據新的聚類,確定下一個基準點(取聚類中各點座標的平均值),然後再次計算出新的聚類;重複此步驟,直到基準點不再發生變化為止。

用到的數學基礎有曼哈頓距離(兩點橫、縱座標差的絕對值之和),算術平均值等。

下面貼一下程式碼

package com.dhc.jstestdemo.Model;

import android.util.Log;

import java.util.ArrayList;
import java.util.List;

/**
 * K-means演算法解析
 * Created by 大漠dreamer on 2018/11/26.
 */

public class KMeans {

    /**
     * 需要一個集合來存放原始座標
     */
    List<Point> original = null;
    Point point1 = null;
    Point point2 = null;
    Point point3 = null;
    Point point4 = null;

    Point point5 = null;
    Point point6 = null;
    Point point7 = null;
    Point point8 = null;

    /**
     * 3個新的聚類
     */
    List<Point> list1 = null;
    List<Point> list2 = null;
    List<Point> list3 = null;

    Point basePointOne = null;
    Point basePointTwo = null;
    Point basePointThree = null;

    /**
     * 聚類計算的次數
     */
    private static final int calculatorNumber = 2;
    private int calculator = 0;

    /**
     * 建構函式,可以根據自己的要求,來定製化需要進行聚類的點
     * 這裡例子採用的是8個點, 3個初始化基準點的方法,最終計算十次之後來確定聚類
     */
    public KMeans() {
    }

    /**
     * 初始化
     * 這裡例子採用的是8個點, 3個初始化基準點的方法,最終來確定聚類
     */
    public void initData() {
        original = new ArrayList<>();
        list1 = new ArrayList<>();
        list2 = new ArrayList<>();
        list3 = new ArrayList<>();


        point1 = new Point(1.0, 2.0);
        original.add(point1);
        point2 = new Point(4.0, 3.0);
        original.add(point2);
        point3 = new Point(3.0, 5.0);
        original.add(point3);
        point4 = new Point(4.0, 9.0);
        original.add(point4);

        point5 = new Point(2.0, 10.0);
        original.add(point5);
        point6 = new Point(6.0, 5.0);
        original.add(point6);
        point7 = new Point(5.0, 2.0);
        original.add(point7);
        point8 = new Point(7.0, 1.0);
        original.add(point8);

        //選取初始點,分別計算曼哈頓聚類距離,此處選取1,4,7為初始點
        basePointOne = point1;
        basePointTwo = point4;
        basePointThree = point7;
    }


    /**
     * 計算點到基準點的距離,並將資料新增到對應的集合中
     *
     * @param point
     */
    private void setPointToCluster(Point point, Point pointBase1
            , Point pointBase2, Point pointBase3) {

        Double distanceForOneToOne = ManHaDunDistance(point, pointBase1);
        Double distanceForOneToFour = ManHaDunDistance(point, pointBase2);
        Double distanceForOneToSeven = ManHaDunDistance(point, pointBase3);

        Double compareOne = Math.min(distanceForOneToOne, distanceForOneToFour);
        Double compareTwo = Math.min(compareOne, distanceForOneToSeven);
        if (compareTwo.equals(distanceForOneToOne)) {
            list1.add(point);
        } else if (compareTwo.equals(distanceForOneToFour)) {
            list2.add(point);
        } else {
            list3.add(point);
        }
    }

    /**
     * 計算下一個聚類,
     */
    public void getNextBasePointAndUpdateCluster() {


        calculator++;
        /**
         * 當遞迴次數已經到達限制次數之後,不再進行遞迴運算,計算停止
         */
        if (calculator == calculatorNumber) {
            return;
        }
        /**
         *  每次計算聚類的時候,清除上一次的聚類資料
         */
        if (list1 != null) {
            list1.clear();
        }
        if (list2 != null) {
            list2.clear();
        }
        if (list3 != null) {
            list3.clear();
        }

        setPointToCluster(point1, basePointOne, basePointTwo, basePointThree);
        setPointToCluster(point2, basePointOne, basePointTwo, basePointThree);
        setPointToCluster(point3, basePointOne, basePointTwo, basePointThree);
        setPointToCluster(point4, basePointOne, basePointTwo, basePointThree);
        setPointToCluster(point5, basePointOne, basePointTwo, basePointThree);
        setPointToCluster(point6, basePointOne, basePointTwo, basePointThree);
        setPointToCluster(point7, basePointOne, basePointTwo, basePointThree);
        setPointToCluster(point8, basePointOne, basePointTwo, basePointThree);

        basePointOne = new Point(getAverage(list1, true), getAverage(list1, false));
        basePointTwo = new Point(getAverage(list2, true), getAverage(list2, false));
        basePointThree = new Point(getAverage(list3, true), getAverage(list3, false));

        /**
         * 遞迴繼續算下一個點和聚類
         */
        getNextBasePointAndUpdateCluster();

    }

    /**
     * 計算數字的加權平均值
     */
    private Double getAverage(List<Point> list, boolean isX) {

        Double sum = 0.0;

        for (int i = 0; i < list.size(); i++) {
            if (isX) {
                sum = sum + list.get(i).getX();
            } else {
                sum = sum + list.get(i).getY();
            }
        }

        return sum / list.size();
    }

    /**
     * 曼哈頓聚類距離
     */
    private Double ManHaDunDistance(Point pointOne, Point pointTwo) {

        return Math.abs(pointTwo.getX() - pointOne.getX()) +
                Math.abs(pointTwo.getY() - pointOne.getY());
    }

    
    public void typeList() {
        for (int i = 0; i < list1.size(); i++) {
            Point point = list1.get(i);
            int index = original.indexOf(point);
            Log.d("cluster", "我來自聚類1----橫座標為:" + point.getX()
                    + "縱座標為:" + point.getY() +
                    "位於原始集合裡面的:" + (index + 1) + "位置");
        }
        for (int i = 0; i < list2.size(); i++) {
            Point point = list2.get(i);
            int index = original.indexOf(point);
            Log.d("cluster", "我來自聚類2----橫座標為:" + point.getX()
                    + "縱座標為:" + point.getY() +
                    "位於原始集合裡面的:" + (index + 1) + "位置");
        }
        for (int i = 0; i < list3.size(); i++) {
            Point point = list3.get(i);
            int index = original.indexOf(point);
            Log.d("cluster", "我來自聚類3----橫座標為:" + point.getX()
                    + "縱座標為:" + point.getY() +
                    "位於原始集合裡面的:" + (index + 1) + "位置");
        }
    }

    /**
     * 座標類
     */
    class Point {

        Double x;
        Double y;

        public Point(Double x, Double y) {
            this.x = x;
            this.y = y;
        }

        public Double getX() {
            return x;
        }

        public void setX(Double x) {
            this.x = x;
        }

        public Double getY() {
            return y;
        }

        public void setY(Double y) {
            this.y = y;
        }
    }
}

測試方法

 /**
  * Demo driver: loads the sample data, runs the clustering and logs the
  * members of each resulting cluster.
  */
 private void clusterCal() {
        final KMeans clusterer = new KMeans();

        // Load the sample points and the initial base points.
        clusterer.initData();

        // Compute the clusters, then print them.
        clusterer.getNextBasePointAndUpdateCluster();
        clusterer.typeList();
    }

執行結果

這裡需要糾正一個錯誤:聚類運算的次數不應該是自己定義的固定值,而是要一直計算到聚類基準點不再發生變化為止,此時才代表計算完成。因此需要手動(在程式碼中)判斷基準點是否還在發生變化。