1. 程式人生 > 其它 >簡單易學的機器學習演算法——Mean Shift聚類演算法

簡單易學的機器學習演算法——Mean Shift聚類演算法

一、Mean Shift演算法概述

Mean Shift演算法,又稱為均值漂移演算法,Mean Shift的概念最早是由Fukunage在1975年提出的,在後來由Yizong Cheng對其進行擴充,主要提出了兩點的改進:

  • 定義了核函式;
  • 增加了權重係數。

核函式的定義使得偏移值對偏移向量的貢獻隨之樣本與被偏移點的距離的不同而不同。權重係數使得不同樣本的權重不同。Mean Shift演算法在聚類,影象平滑、分割以及視訊跟蹤等方面有廣泛的應用。

二、Mean Shift演算法的核心原理

2.1、核函式

上圖的畫圖指令碼如下所示:

'''
Date:201604026
@author: zhaozhiyong
'''
import matplotlib.pyplot as plt
import math

def cal_Gaussian(x, h=1):
    molecule = x * x
    denominator = 2 * h * h
    left = 1 / (math.sqrt(2 * math.pi) * h)
    return left * math.exp(-molecule / denominator)

x = []

for i in xrange(-40,40):
    x.append(i * 0.5);

score_1 = []
score_2 = []
score_3 = []
score_4 = []

for i in x:
    score_1.append(cal_Gaussian(i,1))
    score_2.append(cal_Gaussian(i,2))
    score_3.append(cal_Gaussian(i,3))
    score_4.append(cal_Gaussian(i,4))

plt.plot(x, score_1, 'b--', label="h=1")
plt.plot(x, score_2, 'k--', label="h=2")
plt.plot(x, score_3, 'g--', label="h=3")
plt.plot(x, score_4, 'r--', label="h=4")

plt.legend(loc="upper right")
plt.xlabel("x")
plt.ylabel("N")
plt.show()

2.2、Mean Shift演算法的核心思想

2.2.1、基本原理

對於Mean Shift演算法,是一個迭代的步驟,即先算出當前點的偏移均值,將該點移動到此偏移均值,然後以此為新的起始點,繼續移動,直到滿足最終的條件。此過程可由下圖的過程進行說明(圖片來自參考文獻3):

  • 步驟1:在指定的區域內計算偏移均值(如下圖的黃色的圈)
  • 步驟2:移動該點到偏移均值點處
  • 步驟3: 重複上述的過程(計算新的偏移均值,移動)
  • 步驟4:滿足了最終的條件,即退出

從上述過程可以看出,在Mean Shift演算法中,最關鍵的就是計算每個點的偏移均值,然後根據新計算的偏移均值更新點的位置。

2.2.2、基本的Mean Shift向量形式

2.2.3、改進的Mean Shift向量形式

2.3、Mean Shift演算法的解釋

在Mean Shift演算法中,實際上是利用了概率密度,求得概率密度的區域性最優解。

2.3.1、概率密度梯度

2.3.2、Mean Shift向量的修正

2.4、Mean Shift演算法流程

三、實驗

3.1、實驗資料

實驗資料如下圖所示(來自參考文獻1):

畫圖的程式碼如下:

'''
Date:20160426
@author: zhaozhiyong
'''
import matplotlib.pyplot as plt

f = open("data")
x = []
y = []
for line in f.readlines():
    lines = line.strip().split("t")
    if len(lines) == 2:
        x.append(float(lines[0]))
        y.append(float(lines[1]))
f.close()  

plt.plot(x, y, 'b.', label="original data")
plt.title('Mean Shift')
plt.legend(loc="upper right")
plt.show()

3.2、實驗的原始碼

#!/bin/python
#coding:UTF-8
'''
Date:20160426
@author: zhaozhiyong
'''

import math
import sys
import numpy as np

MIN_DISTANCE = 0.000001#mini error

def load_data(path, feature_num=2):
    f = open(path)
    data = []
    for line in f.readlines():
        lines = line.strip().split("t")
        data_tmp = []
        if len(lines) != feature_num:
            continue
        for i in xrange(feature_num):
            data_tmp.append(float(lines[i]))

        data.append(data_tmp)
    f.close()
    return data

def gaussian_kernel(distance, bandwidth):
    m = np.shape(distance)[0]
    right = np.mat(np.zeros((m, 1)))
    for i in xrange(m):
        right[i, 0] = (-0.5 * distance[i] * distance[i].T) / (bandwidth * bandwidth)
        right[i, 0] = np.exp(right[i, 0])
    left = 1 / (bandwidth * math.sqrt(2 * math.pi))

    gaussian_val = left * right
    return gaussian_val

def shift_point(point, points, kernel_bandwidth):
    points = np.mat(points)
    m,n = np.shape(points)
    #計算距離
    point_distances = np.mat(np.zeros((m,1)))
    for i in xrange(m):
        point_distances[i, 0] = np.sqrt((point - points[i]) * (point - points[i]).T)

    #計算高斯核      
    point_weights = gaussian_kernel(point_distances, kernel_bandwidth)

    #計算分母
    all = 0.0
    for i in xrange(m):
        all += point_weights[i, 0]

    #均值偏移
    point_shifted = point_weights.T * points / all
    return point_shifted

def euclidean_dist(pointA, pointB):
    #計算pointA和pointB之間的歐式距離
    total = (pointA - pointB) * (pointA - pointB).T
    return math.sqrt(total)

def distance_to_group(point, group):
    min_distance = 10000.0
    for pt in group:
        dist = euclidean_dist(point, pt)
        if dist < min_distance:
            min_distance = dist
    return min_distance

def group_points(mean_shift_points):
    group_assignment = []
    m,n = np.shape(mean_shift_points)
    index = 0
    index_dict = {}
    for i in xrange(m):
        item = []
        for j in xrange(n):
            item.append(str(("%5.2f" % mean_shift_points[i, j])))

        item_1 = "_".join(item)
        print item_1
        if item_1 not in index_dict:
            index_dict[item_1] = index
            index += 1

    for i in xrange(m):
        item = []
                for j in xrange(n):
                        item.append(str(("%5.2f" % mean_shift_points[i, j])))

                item_1 = "_".join(item)
        group_assignment.append(index_dict[item_1])

    return group_assignment

def train_mean_shift(points, kenel_bandwidth=2):
    #shift_points = np.array(points)
    mean_shift_points = np.mat(points)
    max_min_dist = 1
    iter = 0
    m, n = np.shape(mean_shift_points)
    need_shift = [True] * m

    #cal the mean shift vector
    while max_min_dist > MIN_DISTANCE:
        max_min_dist = 0
        iter += 1
        print "iter : " + str(iter)
        for i in range(0, m):
            #判斷每一個樣本點是否需要計算偏置均值
            if not need_shift[i]:
                continue
            p_new = mean_shift_points[i]
            p_new_start = p_new
            p_new = shift_point(p_new, points, kenel_bandwidth)
            dist = euclidean_dist(p_new, p_new_start)

            if dist > max_min_dist:#record the max in all points
                max_min_dist = dist
            if dist < MIN_DISTANCE:#no need to move
                need_shift[i] = False

            mean_shift_points[i] = p_new
    #計算最終的group
    group = group_points(mean_shift_points)

    return np.mat(points), mean_shift_points, group

if __name__ == "__main__":
    #匯入資料集
    path = "./data"
    data = load_data(path, 2)

    #訓練,h=2
    points, shift_points, cluster = train_mean_shift(data, 2)

    for i in xrange(len(cluster)):
        print "%5.2f,%5.2ft%5.2f,%5.2ft%i" % (points[i,0], points[i, 1], shift_points[i, 0], shift_points[i, 1], cluster[i])

3.3、實驗的結果

經過Mean Shift演算法聚類後的資料如下所示:

'''
Date:20160426
@author: zhaozhiyong
'''
import matplotlib.pyplot as plt

f = open("data_mean")
cluster_x_0 = []
cluster_x_1 = []
cluster_x_2 = []
cluster_y_0 = []
cluster_y_1 = []
cluster_y_2 = []
center_x = []
center_y = []
center_dict = {}

for line in f.readlines():
    lines = line.strip().split("t")
    if len(lines) == 3:
        label = int(lines[2])
        if label == 0:
            data_1 = lines[0].strip().split(",")
            cluster_x_0.append(float(data_1[0]))
            cluster_y_0.append(float(data_1[1]))
            if label not in center_dict:
                center_dict[label] = 1
                data_2 = lines[1].strip().split(",")
                center_x.append(float(data_2[0]))
                center_y.append(float(data_2[1]))
        elif label == 1:
            data_1 = lines[0].strip().split(",")
            cluster_x_1.append(float(data_1[0]))
            cluster_y_1.append(float(data_1[1]))
            if label not in center_dict:
                center_dict[label] = 1
                data_2 = lines[1].strip().split(",")
                center_x.append(float(data_2[0]))
                center_y.append(float(data_2[1]))
        else:
            data_1 = lines[0].strip().split(",")
            cluster_x_2.append(float(data_1[0]))
            cluster_y_2.append(float(data_1[1]))
            if label not in center_dict:
                center_dict[label] = 1
                data_2 = lines[1].strip().split(",")
                center_x.append(float(data_2[0]))
                center_y.append(float(data_2[1]))    
f.close()


plt.plot(cluster_x_0, cluster_y_0, 'b.', label="cluster_0")
plt.plot(cluster_x_1, cluster_y_1, 'g.', label="cluster_1")
plt.plot(cluster_x_2, cluster_y_2, 'k.', label="cluster_2")
plt.plot(center_x, center_y, 'r+', label="mean point")
plt.title('Mean Shift 2')
#plt.legend(loc="best")
plt.show()

參考文獻

  1. Mean Shift Clustering
  2. Meanshift,聚類演算法
  3. meanshift演算法簡介