1. 程式人生 > >通過Python實踐K-means算法

通過Python實踐K-means算法

對比 散點 分享 k-means append 平均值 算法 ack nump

前言:

今天在宿舍弄了一個下午的代碼,總算還好,把這個東西算是熟悉了,還不算是力竭,只算是知道了怎麽回事。今天就給大家分享一下我的代碼。代碼可以運行,運行的Python環境是Python3.6以上的版本,需要用到Python中的numpy、matplotlib包,這一部分代碼將K-means算法進行了實現。當然這還不是最優的代碼,只是在功能上已經實現了該算法的功能。

代碼部分:

  1 import numpy as np
  2 import random
  3 from matplotlib import pyplot as plt
  4 
  5 class K_means(object):
6 def __init__(self,X,k,maxIter): 7 self.X = X#數據集 是一個矩陣 8 self.k = k#所需要分的類的數 9 self.maxIter = maxIter#所允許的程序執行的最大的循環次數 10 11 def K_means(self): 12 row,col = self.X.shape#得到矩陣的行和列 13 14 dataset = np.zeros((row,col + 1))#新生成一個矩陣,行數不變,列數加1 新的列用來存放分組號別 矩陣中的初始值為0
15 dataset[:,:-1] = self.X 16 print("begin:dataset:\n" + repr(dataset)) 17 # centerpoints = dataset[0:2,:]#取數據集中的前兩個點為中心點 18 centerpoints = dataset[np.random.randint(row,size=k)]#采用隨機函數任意取兩個點 19 20 centerpoints[:,-1] = range(1,self.k+1) 21 oldCenterpoints = None #
用來在循環中存放上一次循環的中心點 22 iterations = 1 #當前循環次數 23 24 while not self.stop(oldCenterpoints,centerpoints,iterations): 25 print("corrent iteration:" + str(iterations)) 26 print("centerpoint:\n" + repr(centerpoints)) 27 print("dataset:\n" + repr(dataset)) 28 29 oldCenterpoints = np.copy(centerpoints)#將本次循環的點拷貝一份 記錄下來 30 iterations += 1 31 32 self.updateLabel(dataset,centerpoints)#將本次聚類好的結果存放到矩陣中 33 34 centerpoints = self.getCenterpoint(dataset)#得到新的中心點,再次進行循環計算 35 36 np.save("kmeans.npy", dataset) 37 return dataset 38 39 def stop(self,oldCenterpoints,centerpoints,iterations): 40 if iterations > self.maxIter: 41 return True 42 return np.array_equal(oldCenterpoints,centerpoints)#返回兩個點多對比結果 43 44 45 def updateLabel(self,dataset,centerpoints): 46 row,col = self.X.shape 47 for i in range(0,row): 48 dataset[i,-1] = self.getLabel(dataset[i,:-1],centerpoints) 49 #[i,j] 表示i行j列 50 51 #返回當前行和中心點之間的距離最短的中心點的類別,即當前點和那個中心點最近就被劃分到哪一部分 52 def getLabel(self,datasetRow,centerpoints): 53 label = centerpoints[0, -1]#先取第一行的標簽值賦值給該變量 54 minDist = np.linalg.norm(datasetRow-centerpoints[0, :-1])#計算兩點之間的直線距離 55 for i in range(1, centerpoints.shape[0]): 56 dist = np.linalg.norm(datasetRow-centerpoints[i, :-1]) 57 if dist < minDist:#當該變距離中心點的距離小於預設的最小值,那麽將最小值進行更新 58 minDist = dist 59 label = centerpoints[i,-1] 60 print("minDist:" + str(minDist) + ",belong to label:" + str(label)) 61 return label 62 63 def getCenterpoint(self,dataset): 64 newCenterpoint = np.zeros((self.k,dataset.shape[1]))#生成一個新矩陣,行是k值,列是數據集的列的值 65 for i in range(1,self.k+1): 66 oneCluster = dataset[dataset[:,-1] == i,:-1]#取出上一次分好的類別的所有屬於同一類的點,對其求平均值 67 newCenterpoint[i-1, :-1] = np.mean(oneCluster,axis=0)#axis=0表示對行求平均值,=1表示對列求平均值 68 newCenterpoint[i-1, -1] = i#重新對新的中心點進行分類,初始類 69 70 return newCenterpoint 71 72 #將散點圖畫出來 73 def drawScatter(self): 74 plt.xlabel("X") 75 plt.ylabel("Y") 76 dataset = self.K_means() 77 x = dataset[:, 0] # 第一列的數值為橫坐標 78 y = dataset[:, 1] # 第二列的數值為縱坐標 79 c = dataset[:, -1] # 最後一列的數值用來區分顏色 80 color = ["none", "b", "r", "g", "y","m","c","k"] 81 c_color = [] 82 83 for i in c: 84 c_color.append(color[int(i)])#給每一種類別的點都塗上不同顏色,便於觀察 85 86 plt.scatter(x=x, y=y, c=c_color, marker="o")#其中x表示橫坐標的值,y表示縱坐標的 87 # 值,c表示該點顯示出來的顏色,marker表示該點多形狀,‘o’表示圓形 88 plt.show() 89 90 91 if __name__ == __main__: 92 93 94 ‘‘‘ 95 關於numpy中的存儲矩陣的方法,這裏不多介紹,可以自行百度。這裏使用的是 96 np.save("filename.npy",X)其中X是需要存儲的矩陣 97 讀取的方法就是代碼中的那一行代碼,可以不用修改任何參數,導出來的矩陣和保存之前的格式一模一樣,很方便。 98 ‘‘‘ 99 # X = np.load("testSet-kmeans.npy")#從文件中讀取數據 100 #自動生成數據 101 X = np.zeros((1,2)) 102 for i in range(1000): 103 X = np.row_stack((X,np.array([random.randint(1,100),random.randint(1,100)]))) 104 k = 5 #表示待分組的組數 105 106 kmeans = K_means(X=X,k=k,maxIter=100) 107 kmeans.drawScatter()

顯示效果:

技術分享圖片

通過Python實踐K-means算法