資料探勘筆記-聚類-SpectralClustering-原理與簡單實現
譜聚類(Spectral Clustering, SC)是一種基於圖論的聚類方法——將帶權無向圖劃分為兩個或兩個以上的最優子圖,使子圖內部儘量相似,而子圖間距離儘量距離較遠,以達到常見的聚類的目的。其中的最優是指最優目標函式不同,可以是Min Cut、Nomarlized Cut、Ratio Cut等。譜聚類能夠識別任意形狀的樣本空間且收斂於全域性最優解,其基本思想是利用樣本資料的相似矩陣(拉普拉斯矩陣)進行特徵分解後得到的特徵向量進行聚類。
Spectral Clustering 演算法步驟:
1)根據資料構造一個Graph,Graph的每一個節點對應一個數據點,將相似的點連線起來,並且邊的權重用於表示資料之間的相似度。把這個Graph
2)把W的每一列元素活者行元素加起來得到N個數,把它們放在對角線上(其他地方都是零),組成一個N*N的度矩陣,記為D 。
3)根據度矩陣與鄰接矩陣得出拉普拉斯矩陣 L = D - W 。
4)求出拉普拉斯矩陣L的前k個特徵值(除非特殊說明,否則“前k個”指按照特徵值的大小從小到大的順序)以及對應的特徵向量。
5)把這k個特徵(列)向量排列在一起組成一個N*k的矩陣,將其中每一行看作k維空間中的一個向量,並使用 K-Means演算法進行聚類。聚類的結果中每一行所屬的類別就是原來Graph中的節點亦即最初的N個數據點分別所屬的類別。
示例
Spectral Clustering
1)和 K-Medoids 類似,Spectral
Clustering 只需要資料之間的相似度矩陣就可以了,而不必像K-means那樣要求資料必須是 N 維歐氏空間中的向量。Spectral
Clustering 所需要的所有資訊都包含在 W 中。不過一般 W 並不總是等於最初的相似度矩陣——回憶一下,W 是我們構造出來的 Graph 的鄰接矩陣表示,通常我們在構造 Graph 的時候為了方便進行聚類,更加強到“區域性”的連通性,亦即主要考慮把相似的點連線在一起,比如:我們可以設定一個閾值,如果兩個點的相似度小於這個閾值,就把他們看作是不連線的。另一種構造
Graph 鄰接的方法是將 n 個與節點最相似的點與其連線起來。
2)由於抓住了主要矛盾,忽略了次要的東西,因此比傳統的聚類演算法更加健壯一些,對於不規則的誤差資料不是那麼敏感,而且效能也要好一些。許多實驗都證明了這一點。事實上,在各種現代聚類演算法的比較中,K-means 通常都是作為 baseline 而存在的。實際上
Spectral Clustering 是在用特徵向量的元素來表示原來的資料,並在這種“更好的表示形式”上進行 K-Means 。實際上這種“更好的表示形式”是用 Laplacian Eig進行降維的後的結果。而降維的目的正是“抓住主要矛盾,忽略次要的東西”。
3)計算複雜度比 K-means 要小。這個在高維資料上表現尤為明顯。例如文字資料,通常排列起來是維度非常高(比如幾千或者幾萬)的稀疏矩陣,對稀疏矩陣求特徵值和特徵向量有很高效的辦法,得到的結果是一些 k 維的向量(通常 k 不會很大),在這些低維的資料上做
K-Means 運算量非常小。但是對於原始資料直接做 K-Means 的話,雖然最初的資料是稀疏矩陣,但是 K-Means 中有一個求 Centroid 的運算,就是求一個平均值:許多稀疏的向量的平均值求出來並不一定還是稀疏向量,事實上,在文字資料裡,很多情況下求出來的 Centroid 向量是非常稠密,這時再計算向量之間的距離的時候,運算量就變得非常大,直接導致普通的 K-Means 巨慢無比,而 Spectral Clustering 等工序更多的演算法則迅速得多的結果。
Java簡單實現程式碼如下:
public class SpectralClusteringBuilder {
public static int DIMENSION = 30;
public static double THRESHOLD = 0.01;
public Data getInitData() {
Data data = new Data();
try {
String path = SpectralClustering.class.getClassLoader()
.getResource("測試").toURI().getPath();
DocumentSet documentSet = DocumentLoader.loadDocumentSet(path);
List<Document> documents = documentSet.getDocuments();
DocumentUtils.calculateTFIDF_0(documents);
DocumentUtils.calculateSimilarity(documents, new CosineDistance());
Map<String, Map<String, Double>> nmap = new HashMap<String, Map<String, Double>>();
Map<String, String> cmap = new HashMap<String, String>();
for (Document document : documents) {
String name = document.getName();
cmap.put(name, document.getCategory());
Map<String, Double> similarities = nmap.get(name);
if (null == similarities) {
similarities = new HashMap<String, Double>();
nmap.put(name, similarities);
}
for (DocumentSimilarity similarity : document.getSimilarities()) {
if (similarity.getDoc2().getName().equalsIgnoreCase(similarity.getDoc1().getName())) {
similarities.put(similarity.getDoc2().getName(), 0.0);
} else {
similarities.put(similarity.getDoc2().getName(), similarity.getDistance());
}
}
}
String[] docnames = nmap.keySet().toArray(new String[0]);
data.setRow(docnames);
data.setColumn(docnames);
data.setDocnames(docnames);
int len = docnames.length;
double[][] original = new double[len][len];
for (int i = 0; i < len; i++) {
Map<String, Double> similarities = nmap.get(docnames[i]);
for (int j = 0; j < len; j++) {
double distance = similarities.get(docnames[j]);
original[i][j] = distance;
}
}
data.setOriginal(original);
data.setCmap(cmap);
data.setNmap(nmap);
} catch (Exception e) {
e.printStackTrace();
}
return data;
}
/**
* 獲取距離閥值在一定範圍內的點
* @param data
* @return
*/
public double[][] getWByDistance(Data data) {
Map<String, Map<String, Double>> nmap = data.getNmap();
String[] docnames = data.getDocnames();
int len = docnames.length;
double[][] w = new double[len][len];
for (int i = 0; i < len; i++) {
Map<String, Double> similarities = nmap.get(docnames[i]);
for (int j = 0; j < len; j++) {
double distance = similarities.get(docnames[j]);
w[i][j] = distance < THRESHOLD ? 1 : 0;
}
}
return w;
}
/**
* 獲取距離最近的K個點
* @param data
* @return
*/
public double[][] getWByKNearestNeighbors(Data data) {
Map<String, Map<String, Double>> nmap = data.getNmap();
String[] docnames = data.getDocnames();
int len = docnames.length;
double[][] w = new double[len][len];
for (int i = 0; i < len; i++) {
List<Map.Entry<String, Double>> similarities =
new ArrayList<Map.Entry<String, Double>>(nmap.get(docnames[i]).entrySet());
sortSimilarities(similarities, DIMENSION);
for (int j = 0; j < len; j++) {
String name = docnames[j];
boolean flag = false;
for (Map.Entry<String, Double> entry : similarities) {
if (name.equalsIgnoreCase(entry.getKey())) {
flag = true;
break;
}
}
w[i][j] = flag ? 1 : 0;
}
}
return w;
}
/**
* 垂直求和
* @param W
* @return
*/
public double[][] getVerticalD(double[][] W) {
int row = W.length;
int column = W[0].length;
double[][] d = new double[row][column];
for (int j = 0; j < column; j++) {
double sum = 0;
for (int i = 0; i < row; i++) {
sum += W[i][j];
}
d[j][j] = sum;
}
return d;
}
/**
* 水平求和
* @param W
* @return
*/
public double[][] getHorizontalD(double[][] W) {
int row = W.length;
int column = W[0].length;
double[][] d = new double[row][column];
for (int i = 0; i < row; i++) {
double sum = 0;
for (int j = 0; j < column; j++) {
sum += W[i][j];
}
d[i][i] = sum;
}
return d;
}
/**
* 相似度排序,並取前K個,倒敘
* @param similarities
* @param k
*/
public void sortSimilarities(List<Map.Entry<String, Double>> similarities, int k) {
Collections.sort(similarities, new Comparator<Map.Entry<String, Double>>() {
@Override
public int compare(Entry<String, Double> o1,
Entry<String, Double> o2) {
return o2.getValue().compareTo(o1.getValue());
}
});
while (similarities.size() > k) {
similarities.remove(similarities.size() - 1);
}
}
public void print(double[][] values) {
for (int i = 0, il = values.length; i < il; i++) {
for (int j = 0, jl = values[0].length; j < jl; j++) {
System.out.print(values[i][j] + " ");
}
System.out.println("\n");
}
}
// 隨機生成中心點,並生成初始的K個聚類
public List<DataPointCluster> genInitCluster(List<DataPoint> points, int k) {
List<DataPointCluster> clusters = new ArrayList<DataPointCluster>();
Random random = new Random();
Set<String> categories = new HashSet<String>();
while (clusters.size() < k) {
DataPoint center = points.get(random.nextInt(points.size()));
String category = center.getCategory();
if (categories.contains(category))
continue;
categories.add(category);
DataPointCluster cluster = new DataPointCluster();
cluster.setCenter(center);
cluster.getDataPoints().add(center);
clusters.add(cluster);
}
return clusters;
}
// 將點歸入到聚類中
public void handleCluster(List<DataPoint> points,
List<DataPointCluster> clusters, int iterNum) {
for (DataPoint point : points) {
DataPointCluster maxCluster = null;
double maxDistance = Integer.MIN_VALUE;
for (DataPointCluster cluster : clusters) {
DataPoint center = cluster.getCenter();
double distance = DistanceUtils.cosine(point.getValues(),
center.getValues());
if (distance > maxDistance) {
maxDistance = distance;
maxCluster = cluster;
}
}
if (null != maxCluster) {
maxCluster.getDataPoints().add(point);
}
}
// 終止條件定義為原中心點與新中心點距離小於一定閥值
// 當然也可以定義為原中心點等於新中心點
boolean flag = true;
for (DataPointCluster cluster : clusters) {
DataPoint center = cluster.getCenter();
DataPoint newCenter = cluster.computeMediodsCenter();
double distance = DistanceUtils.cosine(newCenter.getValues(),
center.getValues());
if (distance > 0.5) {
flag = false;
cluster.setCenter(newCenter);
}
}
if (!flag && iterNum < 25) {
for (DataPointCluster cluster : clusters) {
cluster.getDataPoints().clear();
}
handleCluster(points, clusters, ++iterNum);
}
}
/**
* KMeans方法
* @param dataPoints
*/
public void kmeans(List<DataPoint> dataPoints) {
List<DataPointCluster> clusters = genInitCluster(dataPoints, 4);
handleCluster(dataPoints, clusters, 0);
int success = 0, failure = 0;
for (DataPointCluster cluster : clusters) {
String category = cluster.getCenter().getCategory();
for (DataPoint dataPoint : cluster.getDataPoints()) {
String dpCategory = dataPoint.getCategory();
if (category.equals(dpCategory)) {
success++;
} else {
failure++;
}
}
}
System.out.println("total: " + (success + failure) + " success: "
+ success + " failure: " + failure);
}
public void build() {
Data data = getInitData();
double[][] w = getWByKNearestNeighbors(data);
double[][] d = getHorizontalD(w);
Matrix W = new Matrix(w);
Matrix D = new Matrix(d);
Matrix L = D.minus(W);
EigenvalueDecomposition eig = L.eig();
double[][] v = eig.getV().getArray();
double[][] vs = new double[v.length][DIMENSION];
for (int i = 0, li = v.length; i < li; i++) {
for (int j = 1, lj = DIMENSION; j <= lj; j++) {
vs[i][j-1] = v[i][j];
}
}
Matrix V = new Matrix(vs);
Matrix O = new Matrix(data.getOriginal());
double[][] t = O.times(V).getArray();
List<DataPoint> dataPoints = new ArrayList<DataPoint>();
for (int i = 0; i < t.length; i++) {
DataPoint dataPoint = new DataPoint();
dataPoint.setCategory(data.getCmap().get(data.getColumn()[i]));
dataPoint.setValues(t[i]);
dataPoints.add(dataPoint);
}
for (int n = 0; n < 10; n++) {
kmeans(dataPoints);
}
}
public static void main(String[] args) {
new SpectralClusteringBuilder().build();
}
}
程式碼託管:https://github.com/fighting-one-piece/repository-datamining.git