1. 程式人生 > >k-means++演算法的c++實現

k-means++演算法的c++實現

k-means++是機器學習領域一種基本的聚類演算法,是k-means演算法的增強版,與k-means演算法的唯一區別就在於初始點的選擇上。眾所周知, 通常情況下,k-means選擇初始點都是以一種隨機的方式選擇的,選擇的初始點的好壞,對聚類的效果以及演算法的迭代次數上都有很明顯的影響。最壞的情況如有兩個初始點選在了同一個聚類中,那麼最終有可能導致原本屬於一個聚類的點被分成了兩類。

針對上述k-means的問題,k-means++演算法對初始點的選擇採用一些策略,從而大大改善了演算法的有效性。k-means++演算法是這樣的:

假設:

a:將資料聚成k類;

b:x表示資料集中的任一資料點;

c:Di表示第i個資料點與距離其最近的聚類中心之間的距離平方。

1、隨機生成資料集(可以是任意維的,為了演示方便,我只採用了二維);

2、在資料集中隨機選擇一個資料點,作為我們的第一個聚類中心C1;

3、對每個資料點計算Di(它與已選出的聚類中心中最近者的距離平方),再以概率Di/sum(Di)選擇第i個資料點作為下一個聚類中心;

4、重複3,直到已經找到k個聚類中心{C1,C2,...,Ck};

5、執行k-means演算法。

值得注意的是,可能有人會對第三步以概率Di/sum(Di)選點不是很明白,或者說這到底是怎麼樣的一種選法,應該在程式中如何體現,我先貼出原始碼中一部分:

// k-means++ seeding: fills cents with cents.size() centres drawn from pts, then
// stores in every point the index of the centre closest to it.
template<typename Real, int Dim>
void KMeans<Real, Dim>::kpp(vector<KmPoint> &pts, vector<KmPoint> &cents){
	Real sum = 0;
	// d[i] holds what nearest() reports for pts[i]: the squared distance
	// (dist() sums squared coordinate differences) to the closest chosen centre.
	vector<Real> d;
	d.resize(pts.size());
	// First centre: a point picked with rand().
	cents[0] = pts[rand() % pts.size()];
	// Centres chosen so far; nearest() is evaluated against these only.
	vector<KmPoint> tmpCents;
	tmpCents.push_back(cents[0]);
	for(int k = 1; k < (int)cents.size(); ++k){
		sum = 0;
		for(int i = 0; i < (int)pts.size(); ++i){
			nearest(pts[i], tmpCents, d[i]);
			sum += d[i];
		}
		// Replace the total by a random threshold derived from it.
		sum = randf(sum);
		for(int i = 0; i < (int)pts.size(); ++i){
			// Walk the points, subtracting each weight from the threshold; the
			// first point that brings it to zero or below becomes the new centre.
			<strong>if((sum -= d[i]) > 0)	continue;</strong>
			cents[k] = pts[i];
			tmpCents.push_back(cents[k]);
			break;
		}
		// NOTE(review): if the loop above finishes without reaching the break,
		// cents[k] is left untouched and tmpCents is not extended.
	}
	for(int i = 0; i < (int)pts.size(); ++i){
		// The distance output is not needed here.
		// NOTE(review): the Real allocated with new is never deleted.
		int id = nearest(pts[i], cents, *(new Real));
		pts[i].setId(id);
	}
}

也許大家可以從這段程式碼中窺得一些思想。我們知道,概率事件在程式中是可以用隨機數模擬的,這裡以概率p選擇下一個聚類中心,正是利用了用隨機數模擬概率事件的思想。我們可以這樣理解上述這段程式碼:首先,我們已經統計出了每個點到已選出的聚類中心的最近距離(平方),這個距離存放在d這個陣列中,所有最近距離的和為sum;然後隨機地從0~sum之間選擇一個數(需要注意,這裡sum是所有距離的和)。注意程式碼中加黑加亮部分,可以想象兩種情況,一種d[i]很大,記為dmax,一種d[i]很小,記為dmin,那麼sum - dmax是不是比sum - dmin更容易打破大於0這個條件?並且,一旦大於0的條件被打破,接下來就是選擇對應於d[i]的點作為我們的聚類中心。換句話說,這相當於把區間[0, sum]依次切成長度為d[0]、d[1]、……的小段,隨機數落在哪一段就選哪個點,而落在第i段的概率恰好是d[i]/sum。是的,以Di/sum(Di)的概率選擇下一個聚類中心的思想正是在這裡得以體現,可以說非常巧妙。k-means++的原理和演算法部分就說完了,其實主要是演算法的第三步,很多網上的資料都沒有對此作出說明,只是機械式的給出了演算法的步驟,希望我的理解能解開困擾一些人的疑惑。

下面是完整的原始碼:

//KmLib.h

// Decoration applied to the library's public types.
// Windows + KM_EXPORTS (building the library) -> export the symbols;
// Windows otherwise (consuming the library)   -> import them;
// any other platform                          -> expands to nothing.
#if defined(_WIN32) && defined(KM_EXPORTS)
#define KM_API __declspec(dllexport)
#elif defined(_WIN32)
#define KM_API __declspec(dllimport)
#else
#define KM_API
#endif

//KmPoint.h
#ifndef KMPOINT_H
#define KMPOINT_H
#include <ostream>
#include "KmLib.h"

// A point with Dim coordinates of type Real, tagged with the index of the
// cluster (group) it currently belongs to.
//
// Copying uses the compiler-generated member-wise copy constructor and
// assignment; the type owns no resources, so nothing custom is required.
template<typename Real = double, int Dim = 2>
struct KM_API KmPoint{
public:
	typedef KmPoint<Real, Dim> Self;

	static_assert(Dim > 0, "KmPoint needs at least one coordinate");

	// Origin point, assigned to group 0.
	KmPoint() : x(), gid(0) {}

	// Copies the Dim coordinates from coords; the group defaults to -1.
	// Accepts both const and non-const arrays.
	KmPoint(const Real (&coords)[Dim], int gid = -1) : gid(gid) {
		for(int i = 0; i < Dim; ++i){
			x[i] = coords[i];
		}
	}

	// Unchecked coordinate access; idx must lie in [0, Dim).
	Real &operator[](int idx){	return x[idx];	}
	const Real &operator[](int idx) const {	return x[idx]; }

	// Group index accessors. id() is const so that const points (and
	// temporaries bound to const references) can be queried.
	int id() const { return gid; }
	void setId(int id){
		gid = id;
	}

	// Coordinate idx by value; idx must lie in [0, Dim).
	Real get(int idx) const { return x[idx]; }
private:
	Real x[Dim];
	int gid;
};

// Streams a point as "x0 x1 ... x(Dim-1) id", fields separated by one space.
// Takes the point by const reference so const objects and temporaries can be
// printed; the stream type is spelled std::ostream so the header does not
// depend on a using-directive in whichever file includes it.
template<class Real, int Dim>
std::ostream& operator<<(std::ostream &out, const KmPoint<Real, Dim> &other){
	out<<other[0];
	for(int i = 1; i < Dim; ++i)
		out<<' '<<other[i];
	out<<' '<<other.id();
	return out;
}

#endif

//Kmeans.h
#ifndef KMEANS_H
#define KMEANS_H
#include <stdlib.h>
#include <stdio.h>
#include <math.h>
#include <time.h>
#include <vector>
#include <fstream>
#include <iostream>
using namespace std;
#include "KmPoint.h"
#define PI 3.14159265358979323846

// k-means clustering with k-means++ seeding over points with Dim coordinates
// of type Real. The longer member functions are defined in Kmeans.cpp.
template<typename Real = double, int Dim = 2>
class KM_API KMeans{
public:
	// Point type used throughout the class. The global template is named with
	// a leading :: so that introducing a member called KmPoint does not change
	// the meaning of a name already used inside the class, which conforming
	// compilers such as GCC reject.
	typedef ::KmPoint<Real, Dim> KmPoint;
public:
	// Seeds the C random number generator from the current time.
	KMeans(){ srand((unsigned)time(0)); }

	// Random value in [0, m]. The ratio rand()/RAND_MAX never exceeds 1, so
	// the result cannot overshoot m; kpp() relies on this when it draws a
	// threshold out of the total weight.
	Real randf(Real m){
		return m * static_cast<Real>(rand() / static_cast<double>(RAND_MAX));
	}

	// Squared Euclidean distance between a and b (no square root is taken).
	Real dist(const KmPoint &a, const KmPoint &b){
		Real len = 0;
		for(int i = 0; i < Dim; ++i){
			const Real diff = a[i] - b[i];
			len += diff * diff;
		}
		return len;
	}

	// Fills pts with count random points lying within radius of the origin.
	void dataGenerator(int count, Real radius, vector<KmPoint> &pts);
	// Returns the index in cents of the centre closest to pt and stores the
	// corresponding squared distance in d.
	int nearest(KmPoint &pt, vector<KmPoint> &cents, Real &d);
	// k-means++ seeding: chooses cents.size() initial centres out of pts and
	// tags every point with the index of its closest centre.
	void kpp(vector<KmPoint> &pts, vector<KmPoint> &cents);
	// Clusters pts into k groups; outCents receives the centres and outPts the
	// points of each group.
	void kmcluster(vector<KmPoint> &pts, int k, vector<KmPoint> &outCents, vector<vector<KmPoint>> &outPts);
	// Writes the clustered points to ./cluster.txt, one point per line.
	void serialize(vector<vector<KmPoint>> &outPts);
};

#endif

//Kmeans.cpp
#include "Kmeans.h"

// Writes count random points into pts[0..count). Each point gets a random
// angle in [0, 2*PI] and a random distance from the origin of at most radius;
// only the first two coordinates are set, the remaining ones stay zero.
//
// pts is grown when it holds fewer than count elements (a shorter, non-empty
// vector used to be written past its end); a longer vector keeps its tail.
// A non-positive count leaves pts untouched.
template<typename Real, int Dim>
void KMeans<Real, Dim>::dataGenerator(int count, Real radius, vector<KmPoint> &pts){
	if(count <= 0)
		return;
	if((int)pts.size() < count)
		pts.resize(count);
	for(int i = 0; i < count; ++i){
		const Real ang = randf(2 * PI);
		const Real r = randf(radius);
		KmPoint p;
		p[0] = r * cos(ang);
		if(Dim > 1)	// a one-dimensional point has no second coordinate to fill
			p[1] = r * sin(ang);
		pts[i] = p;
	}
}

// Finds the centre in cents closest to pt.
//
// Returns the index of that centre and stores its squared distance to pt in d.
// When cents is empty (or no distance compares below HUGE_VAL) the point's
// current id is returned and d is set to HUGE_VAL.
template<typename Real, int Dim>
int KMeans<Real, Dim>::nearest(KmPoint &pt, vector<KmPoint> &cents, Real &d){
	int best = pt.id();
	Real bestDist = HUGE_VAL;
	for(int i = 0; i < (int)cents.size(); ++i){
		// Measure against the stored centre directly; copying every centre on
		// every call was pure overhead in the innermost loop of the algorithm.
		const Real cur = dist(cents[i], pt);
		if(cur < bestDist){
			bestDist = cur;
			best = i;
		}
	}
	d = bestDist;
	return best;
}

// k-means++ seeding.
//
// Chooses cents.size() initial centres out of pts: the first one uniformly at
// random, every further one with probability proportional to the squared
// distance between a point and the closest centre chosen so far. Afterwards
// every point in pts is tagged with the index of its closest centre.
//
// Does nothing when pts or cents is empty.
template<typename Real, int Dim>
void KMeans<Real, Dim>::kpp(vector<KmPoint> &pts, vector<KmPoint> &cents){
	if(pts.empty() || cents.empty())
		return;
	const int n = (int)pts.size();
	const int k = (int)cents.size();

	cents[0] = pts[rand() % n];

	// d[i]: squared distance from pts[i] to the closest centre chosen so far.
	// Only the newest centre can lower it, so one pass per centre is enough.
	vector<Real> d(n, HUGE_VAL);
	for(int c = 1; c < k; ++c){
		Real sum = 0;
		for(int i = 0; i < n; ++i){
			const Real toNewest = dist(pts[i], cents[c - 1]);
			if(toNewest < d[i])
				d[i] = toNewest;
			sum += d[i];
		}

		// Roulette-wheel draw: cut [0, sum] into consecutive segments of length
		// d[i] and take the point whose segment contains the random threshold.
		Real target = randf(sum);
		int chosen = -1;
		int lastWeighted = -1;
		for(int i = 0; i < n; ++i){
			if(!(d[i] > 0))
				continue;	// zero weight (already a centre): must never be drawn
			lastWeighted = i;
			if((target -= d[i]) > 0)
				continue;
			chosen = i;
			break;
		}
		// Rounding can leave the threshold marginally positive after the scan;
		// fall back to the last drawable point so that a centre is always set.
		// If every point coincides with a centre, any point is as good.
		if(chosen < 0)
			chosen = lastWeighted >= 0 ? lastWeighted : rand() % n;
		cents[c] = pts[chosen];
	}

	Real unused = 0;	// nearest() insists on a distance output; it is not needed here
	for(int i = 0; i < n; ++i)
		pts[i].setId(nearest(pts[i], cents, unused));
}

// Clusters pts with k-means, seeded by kpp().
//
// pts      points to cluster; each one ends up tagged with its cluster index.
// k        number of clusters; only consulted when outCents is passed in
//          empty, otherwise outCents.size() is the number of clusters.
// outCents receives the final centres, centre i carrying id i.
// outPts   outPts[i] has the points of cluster i appended to it; it is grown
//          when it has fewer entries than there are clusters.
//
// Iteration stops once at most 0.1% of the points change cluster in a pass.
// Returns without clustering when pts is empty or no cluster was requested.
template<typename Real, int Dim>
void KMeans<Real, Dim>::kmcluster(vector<KmPoint> &pts, int k, vector<KmPoint> &outCents, vector<vector<KmPoint>> &outPts){
	if(outCents.empty()){
		if(k <= 0)
			return;
		outCents.resize(k);
	}
	// Everything below is sized from the real number of centres, so a caller
	// that pre-sized outCents differently from k cannot cause out-of-range access.
	const int clusters = (int)outCents.size();
	if((int)outPts.size() < clusters)
		outPts.resize(clusters);
	if(pts.empty())
		return;

	kpp(pts, outCents);

	vector<Real> sums(clusters * Dim);	// per-cluster coordinate sums, row-major
	vector<int> cnt(clusters);		// per-cluster point counts
	Real unused = 0;			// distance output of nearest(), not needed here
	int changed;
	do{
		sums.assign(sums.size(), Real());
		cnt.assign(cnt.size(), 0);
		for(int i = 0; i < (int)pts.size(); ++i){
			const int c = pts[i].id();
			for(int j = 0; j < Dim; ++j)
				sums[c * Dim + j] += pts[i][j];
			++cnt[c];
		}
		for(int c = 0; c < clusters; ++c){
			// A cluster that lost all its points keeps its previous centre;
			// dividing by a zero count would poison it with NaN/inf.
			if(cnt[c] == 0)
				continue;
			for(int j = 0; j < Dim; ++j)
				outCents[c][j] = sums[c * Dim + j] / cnt[c];
		}
		changed = 0;
		for(int i = 0; i < (int)pts.size(); ++i){
			const int min_i = nearest(pts[i], outCents, unused);
			if(min_i != pts[i].id()){
				++changed;
				pts[i].setId(min_i);
			}
		}
	}while(changed > 0.001 * pts.size());

	for(int c = 0; c < clusters; ++c)
		outCents[c].setId(c);
	for(int i = 0; i < (int)pts.size(); ++i)
		outPts[pts[i].id()].push_back(pts[i]);
}

// Writes every point of every cluster to ./cluster.txt, one point per line in
// the format produced by operator<< for KmPoint. Reports to standard output
// and returns when the file cannot be opened.
template<typename Real, int Dim>
void KMeans<Real, Dim>::serialize(vector<vector<KmPoint>> &outPts){
	ofstream ofs("./cluster.txt", ofstream::out);
	if(!ofs.is_open()){
		cout<<"open file failed!"<<endl;
		return ;
	}
	for(size_t i = 0; i < outPts.size(); ++i){
		for(size_t j = 0; j < outPts[i].size(); ++j){
			// '\n' instead of endl: same bytes in the file, but without forcing
			// a flush after every single point.
			ofs<<outPts[i][j]<<'\n';
		}
	}
	// The stream flushes and closes itself when it goes out of scope.
}

//main
#include "Kmeans.cpp"

int main(){
	KMeans<> km;
	vector<KmPoint<>> pts, outCents;
	int k = 7;
	vector<vector<KmPoint<>>> outPts;
	km.dataGenerator(10000, 10, pts);
	km.kmcluster(pts, k, outCents, outPts);
	km.serialize(outPts);
	return 0;
}

以下是聚類結果:

7個聚類


11個聚類


<pre code_snippet_id="331410" snippet_file_name="blog_20140507_2_7888148">