聚類——譜聚類演算法以及Python實現
阿新 • • 發佈:2019-01-08
譜聚類(spectral cluster)可以視為一種改進的Kmeans的聚類演算法。常用來進行影象分割。缺點是需要指定簇的個數,難以構建合適的相似度矩陣。優點是簡單易實現。相比Kmeans而言,處理高維資料更合適。
核心思想
構建樣本點的相似度矩陣(圖),將圖切割成K個子圖,使得各個子圖內相似度最大,子圖間相似度最弱
演算法簡介
構建相似度矩陣的拉普拉斯矩陣。對拉普拉斯矩陣進行特徵值分解,選取前K(也是簇的個數)個特徵向量(按特徵值從小到大的順序)構成K維特徵空間,在特徵空間內進行Kmeans聚類。概括地講,就是將原始資料對映到特徵空間進行Kmeans聚類。因此,譜聚類適合於簇的個數比較小的情況下。
拉普拉斯矩陣可以分為規範化的()和未規範化的拉普拉斯矩陣()。其中D為圖的度矩陣(對角矩陣,節點邊權重之和),W為相似度矩陣。
演算法流程
- Input: 訓練資料集data,簇的個數, 閾值epsilon, 最大迭代次數maxstep, 相似度計算方法及引數
- Output: 標籤陣列
- Step1:構建相似度矩陣,再構建拉普拉斯矩陣,對拉普拉斯矩陣進行特徵值分解,將樣本資料點對映到特徵空間。
- Step2: 再特徵空間內進行Kmeans聚類。
程式碼
"""
譜聚類演算法
核心思想:構建樣本點的圖,切分圖,使得子圖內權重最大,子圖間權重最小
"""
import numpy as np
from kmeans import KMEANS
class Spectrum:
def __init__(self, n_cluster, epsilon=1e-3, maxstep=1000, method='unnormalized',
criterion='gaussian', gamma=2.0, dis_epsilon=70, k=5):
self.n_cluster = n_cluster
self.epsilon = epsilon
self.maxstep = maxstep
self.method = method # 本程式提供規範化以及非規範化的譜聚類演算法
self.criterion = criterion # 相似性矩陣的構建方法
self.gamma = gamma # 高斯方法中的sigma引數
self.dis_epsilon = dis_epsilon # epsilon-近鄰方法的引數
self.k = k # k近鄰方法的引數
self.W = None # 圖的相似性矩陣
self.L = None # 圖的拉普拉斯矩陣
self.L_norm = None # 規範化後的拉普拉斯矩陣
self.D = None # 圖的度矩陣
self.cluster = None
self.N = None
def init_param(self, data):
# 初始化引數
self.N = data.shape[0]
dis_mat = self.cal_dis_mat(data)
self.cal_weight_mat(dis_mat)
self.D = np.diag(self.W.sum(axis=1))
self.L = self.D - self.W
return
def cal_dis_mat(self, data):
# 計算距離平方的矩陣
dis_mat = np.zeros((self.N, self.N))
for i in range(self.N):
for j in range(i + 1, self.N):
dis_mat[i, j] = (data[i] - data[j]) @ (data[i] - data[j])
dis_mat[j, i] = dis_mat[i, j]
return dis_mat
def cal_weight_mat(self, dis_mat):
# 計算相似性矩陣
if self.criterion == 'gaussian': # 適合於較小樣本集
if self.gamma is None:
raise ValueError('gamma is not set')
self.W = np.exp(-self.gamma * dis_mat)
elif self.criterion == 'k_nearest': # 適合於較大樣本集
if self.k is None or self.gamma is None:
raise ValueError('k or gamma is not set')
self.W = np.zeros((self.N, self.N))
for i in range(self.N):
inds = np.argpartition(dis_mat[i], self.k + 1)[:self.k + 1] # 由於包括自身,所以+1
tmp_w = np.exp(-self.gamma * dis_mat[i][inds])
self.W[i][inds] = tmp_w
elif self.criterion == 'eps_nearest': # 適合於較大樣本集
if self.dis_epsilon is None:
raise ValueError('epsilon is not set')
self.W = np.zeros((self.N, self.N))
for i in range(self.N):
inds = np.where(dis_mat[i] < self.dis_epsilon)
self.W[i][inds] = 1.0 / len(inds)
else:
raise ValueError('the criterion is not supported')
return
def fit(self, data):
# 訓練主函式
self.init_param(data)
if self.method == 'unnormalized':
w, v = np.linalg.eig(self.L)
inds = np.argsort(w)[:self.n_cluster]
Vectors = v[:, inds]
elif self.method == 'normalized':
D = np.linalg.inv(np.sqrt(self.D))
L = D @ self.L @ D
w, v = np.linalg.eig(L)
inds = np.argsort(w)[:self.n_cluster]
Vectors = v[:, inds]
normalizer = np.linalg.norm(Vectors, axis=1)
normalizer = np.repeat(np.transpose([normalizer]), self.n_cluster, axis=1)
Vectors = Vectors / normalizer
else:
raise ValueError('the method is not supported')
km = KMEANS(self.n_cluster, self.epsilon, self.maxstep)
km.fit(Vectors)
self.cluster = km.cluster
return
if __name__ == '__main__':
from sklearn.datasets import make_blobs
from itertools import cycle
import matplotlib.pyplot as plt
data, label = make_blobs(centers=3, n_features=10, cluster_std=1.2, n_samples=500, random_state=1)
sp = Spectrum(n_cluster=3, method='unnormalized', criterion='gaussian', gamma=0.1)
sp.fit(data)
cluster = sp.cluster
# km = KMEANS(4)
# km.fit(data)
# cluster_km = km.cluster
# def visualize(data, cluster):
# color = 'bgrym'
# for col, inds in zip(cycle(color), cluster.values()):
# partial_data = data[inds]
# plt.scatter(partial_data[:, 0], partial_data[:, 1], color=col)
# plt.show()
# return
# visualize(data, cluster)
def cal_err(data, cluster):
# 計算MSE
mse = 0
for label, inds in cluster.items():
partial_data = data[inds]
center = partial_data.mean(axis=0)
for p in partial_data:
mse += (center - p) @ (center - p)
return mse / data.shape[0]
print(cal_err(data, cluster))
# print(cal_err(data, cluster_km))
我的GitHub
注:程式碼尚未進行嚴格測試,如有不當之處,請指正。