1. 程式人生 > >k-means用於影象的顏色聚類

k-means用於影象的顏色聚類

#coding:utf-8
__author__ = 'jmh081701'
#本檔案主要學習一種經典的聚類方法:k-means
#我們把這個演算法用於一個RGB影象的聚類,看能出來的什麼的效果
#k-means的原理:
'''
輸入:x[1],x[2],x[3],...,x[n],其中每個x[i]都是m維的向量,給定聚類的數目k
1.隨機生成k個代表元:z[1],z[2],...,z[k];每個z[i]都是第i類的中心元
2.repeat:
  更新 xi所述的類別ci,使得:|x[i]-z[ci]|最小
  更新 z[j],z[j]等於所在類別G[j]的所有樣本的平均值
until:z不再改變
'''
import numpy as np import math import random from PIL import Image cnt=0 def calculate_zi(Gi,X): #給定Gi,裡面包含著屬於這個類別的元素,然後計算這些元素的中心點 #在本例項中,Gi裡面包含的是下標 global cnt sumi=np.zeros(len(X[0])) for each in Gi: cnt+=1 sumi+=X[each] sumi/=(len(Gi)+0.000000001) zi=sumi return zi def
find_ci(xi,Z):
#尋找離xi最近的中心元素ci,使得Z[ci]與xi之間的向量差的內積最小 global cnt dis_= np.inf len_=len(Z) rst_index = None for i in range(len_): cnt+=1 tmp_dist=np.dot(xi-Z[i],np.transpose(xi-Z[i])) if tmp_dist<dis_: rst_index=i dis_=tmp_dist return
rst_index def k_mean(X,k): G=[] #G[i]={1,2,3...}表示屬於第i類的樣本在X中的索引,洗標 Z=[] #Z[i] 第i類的中心點 N=len(X) c=[] #c[i]=1,2,...,k;表示第i個樣本屬於第c[i]類 tmpr=set() while len(Z)<k: r=random.randint(0,len(X)-1) if r not in tmpr: tmpr.add(r) Z.append(X[r]) G.append(set()) for i in range(N): c.append(0) #隨機生成K箇中心元素 while True: group_flag=np.zeros(k) for i in range(N): new_ci = find_ci(X[i],Z) if c[i] != new_ci: #找到了更好的,把xi從原來的c[i]調到new_ci去,於是有兩個組需要更新:new_ci,c[i] if i in G[c[i]]: G[c[i]].remove(i) group_flag[c[i]]=1 #把i從原來所屬的組中移出來 G[new_ci].add(i) group_flag[new_ci]=1 #把i加入到新的所屬組去 c[i]=new_ci #上面已經更新好了各元素的所屬 if np.sum(group_flag)==0: #沒有組被修改 break for i in range(k): if group_flag[i]==0: #未修改,無須重新計算 continue else: Z[i]=calculate_zi(list(G[i]),X) return Z,c,k def test_rgb_img(): filename=r"1.jpg" im = Image.open(filename) img = im.load() im.close() height = im.size[0] width= im.size[1] print(im.size) X=[] for i in range(0,height): for j in range(0,width): X.append(np.array(img[i,j])) Z,c,k=k_mean(X,8) #print(Z) new_im = Image.new("RGB",(height,width)) for i in range(0,height): for j in range(0,width): index = i * width + j pix = list(Z[c[index]]) for k in range(len(pix)): pix[k]=int(pix[k]) new_im.putpixel((i,j),tuple(pix)) new_im.show() if __name__ == '__main__': test_rgb_img() print(cnt)

原圖:
這裡寫圖片描述
k=8的聚類結果:

這裡寫圖片描述

k=4的聚類結果:
這裡寫圖片描述

k=2:聚類結果

這裡寫圖片描述