1. 程式人生 > >聚類——MeanShift演算法以及Python實現

聚類——MeanShift演算法以及Python實現

均值漂移演算法(MeanShift)是一種旨在發現團(blobs)的聚類演算法

核心思想

尋找核密度極值點並作為簇的質心,然後根據最近鄰原則將樣本點賦予質心

演算法簡介

核密度估計

根據樣本分佈估計在樣本空間的每一點的密度。估計某點的密度時,核密度估計方法會考慮該點鄰近區域的樣本點的影響,鄰近區域大小由頻寬h決定,該引數對最終密度估計的影響非常大。通常採用高斯核:N(x)=12πhex22h2N(x)=\frac{1}{\sqrt{2\pi}h}e^{\frac{x^2}{2h^2}}

均值漂移

演算法初始化一個質心(向量表示),每一步迭代都會朝著當前質心領域內密度極值方向漂移,方向就是密度上升最大的方向,即梯度方向

參考資料)。求導即得漂移向量的終點就是下一個質心:xki+1=xjNK(xjxki)xjxjNK(xjxki)x_k^{i+1}=\frac{\sum_{x_j\in N}K(x_j-x_k^i)x_j}{\sum_{x_j\in N}K(x_j-x_k^i)}N是當前質心的領域內樣本集合。這個公式可以從另一方面理解:當前質心的鄰域內樣本點以核密度為權重的均值就是更新後的質心

演算法流程

  • Input: 高斯核頻寬,bin_seeding(是否對資料粗粒化), min_fre(可以作為起始質心的樣本點領域內的最少樣本數),閾值epsilon(漂移最小長度)
  • Output: 樣本簇標籤
  • Step1: 獲取可以作為起始質心的樣本點
  • Step2: 對每個起始質心進行漂移,漂移終止條件就是漂移距離小於epsilon。若漂移結束最終的質心與已存在的質心距離小於頻寬則合併
  • Step3: 分類。將樣本點歸屬到距離最近的質心中

程式碼

"""
meanshift聚類演算法
核心思想:
尋找核密度極值點並作為簇的質心,然後根據最近鄰原則將樣本點賦予質心
"""
from collections import defaultdict
import numpy as np
import math


class MeanShift:
    def __init__
(self, band_width=2.0, min_fre=3, epsilon=None, bin_seeding=False, bin_size=None): self.epsilon = epsilon if epsilon else 1e-3 * band_width self.bin_size = bin_size if bin_size else self.band_width self.band_width = band_width self.min_fre = min_fre # 可以作為起始質心的球體內最少的樣本數目 self.bin_seeding = bin_seeding self.radius2 = self.band_width ** 2 # 高維球體半徑的平方 self.N = None self.labels = None self.centers = [] def init_param(self, data): # 初始化引數 self.N = data.shape[0] self.labels = -1 * np.ones(self.N) return def get_seeds(self, data): # 獲取可以作為起始質心的點(seed) if not self.bin_seeding: return data seed_list = [] seeds_fre = defaultdict(int) for sample in data: seed = tuple(np.round(sample / self.bin_size)) # 將資料粗粒化,以防止非常近的樣本點都作為起始質心 seeds_fre[seed] += 1 for seed, fre in seeds_fre.items(): if fre >= self.min_fre: seed_list.append(np.array(seed)) if not seed_list: raise ValueError('the bin size and min_fre are not proper') if len(seed_list) == data.shape[0]: return data return np.array(seed_list) * self.bin_size def euclidean_dis2(self, center, sample): # 計算均值點到每個樣本點的歐式距離(平方) delta = center - sample return delta @ delta def gaussian_kel(self, dis2): # 計算高斯核 return 1.0 / self.band_width * (2 * math.pi) ** (-1.0 / 2) * math.exp(-dis2 / (2 * self.band_width ** 2)) def shift_center(self, current_center, data): # 計算下一個漂移的座標 denominator = 0 # 分母 numerator = np.zeros_like(current_center) # 分子, 一維陣列形式 for sample in data: dis2 = self.euclidean_dis2(current_center, sample) if dis2 <= self.radius2: d = self.gaussian_kel(dis2) denominator += d numerator += d * sample if denominator > 0: return numerator / denominator else: return None def classify(self, data): # 根據最近鄰將資料分類到最近的簇中 center_arr = np.array(self.centers) for i in range(self.N): delta = center_arr - data[i] dis2 = np.sum(delta * delta, axis=1) self.labels[i] = np.argmin(dis2) return def fit(self, data): # 訓練主函式 self.init_param(data) seed_list = self.get_seeds(data) for seed in seed_list: bad_seed = False current_center = seed # 進行一次獨立的均值漂移 while True: next_center = self.shift_center(current_center, data) if next_center is None: bad_seed = True break delta_dis = np.linalg.norm(next_center - current_center, 2) if delta_dis < self.epsilon: break current_center = next_center if not bad_seed: # 若該次漂移結束後,最終的質心與已存在的質心距離小於頻寬,則合併 for i in range(len(self.centers)): if np.linalg.norm(current_center - self.centers[i], 2) < self.band_width: break else: self.centers.append(current_center) self.classify(data) return if __name__ == '__main__': from sklearn.datasets import make_blobs data, label = make_blobs(n_samples=500, centers=5, cluster_std=1.2, random_state=7) MS = MeanShift(band_width=3, min_fre=3, bin_size=4, bin_seeding=True) MS.fit(data) labels = MS.labels print(MS.centers, np.unique(labels)) import matplotlib.pyplot as plt from itertools import cycle def visualize(data, labels): color = 'bgrymk' unique_label = np.unique(labels) for col, label in zip(cycle(color), unique_label): partial_data = data[np.where(labels == label)] plt.scatter(partial_data[:, 0], partial_data[:, 1], color=col) plt.show() return visualize(data, labels)

我的GitHub
注:如有不當之處,請指正。