機器學習小實戰(四) KMeans聚類
阿新 • • 發佈:2019-01-08
目錄
一、 KMeans聚類簡介
需要事先指定簇的數目k
優化目標:所有點到各自質心的距離之和最小
特點:受初始值(K個隨機質心的位置)的影響挺大的、受形狀的影響還有點大
二、小案例
讀資料、演算法例項化(設定引數),訓練模型、展示與分析
1. 讀取資料,瞭解一下
import numpy as np
import pandas as pd
import matplotlib as plt
beer=pd.read_csv('data.txt',sep=' ')
print(beer.shape) #(20, 5)
print(beer.head())
2. 資料預處理
給定資料集有5列,第一列是名字,與特徵沒什麼關係,所以將後面四列提取出來,作為接下來聚類的資料。
X=beer[['calories','sodium','alcohol','cost']]
3. KMeans聚類演算法
演算法例項化:指定簇的個數為3或2,然後將資料傳入進行訓練
from sklearn.cluster import KMeans
km_3=KMeans(n_clusters=3).fit(X) #一行完成演算法的例項化和傳入資料
km_2=KMeans(n_clusters=2).fit(X)
km_3.labels_
結果:array([0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 0, 2, 0, 0, 0, 1, 0, 0, 1, 2])
輸出的這個陣列中有3種取值:0,1,2對應三個不同的簇的標籤
beer['cluster_3']=km_3.labels_ #將標籤值(所在的是哪個簇)作為新的特徵存入csv中
beer['cluster_2']=km_2.labels_
beer.sort_values('cluster_3') #按值內容排序
4. 結果視覺化
多個特徵,做二維視覺化時就只能選擇2個特徵進行視覺化咯
from pandas.tools.plotting import scatter_matrix #散佈圖矩陣 cluster_centers3=km_3.cluster_centers_ cluster_centers2=km_2.cluster_centers_ beer.groupby('cluster_3').mean() #計算不同類別所對應的其他屬性的平均值 beer.groupby('cluster_2').mean() centers=beer.groupby('cluster_3').mean().reset_index() plt.rcParams['font.size']=14 colors=np.array(['red','green','blue','yello']) #散點圖中,兩個座標分別是calories和alcohol的取值 plt.scatter(beer['calories'],beer['alcohol'],c=colors[beer['cluster_3']]) plt.scatter(centers.calories,centers.alcohol,linewidth=3,marker='+',s=300,c='black') plt.xlabel('calories') plt.ylabel('alcohol')
結果:centers長啥樣——
四、 KMeans用於影象壓縮
1. 讀取影象
# -*- coding:utf-8 -*-這行太重要了,在anaconda程式設計時,要是忘記加上這行,可就顯示不出影象了呢
# -*- coding:utf-8 -*-
from skimage import io
from sklearn.cluster import KMeans
import numpy as np
image=io.imread('img.jpg')
io.imshow(image)
io.show()
#print(image.shape) #(647, 650, 3) 原來3個通道
2. KMeans壓縮
rows=image.shape[0]
cols=image.shape[1]
image=image.reshape(image.shape[0]*image.shape[1],3) #一張圖的畫素點排成一列,3表示3個通道的值
kmeans=KMeans(n_clusters=128,n_init=10,max_iter=200) #例項化kmeass,指定n為128(原來是256,並且還有3個通道呢)
kmeans.fit(image) # 例項化kmeans後,傳入物件image
clusters=np.asarray(kmeans.cluster_centers_,dtype=np.uint8)#把聚類之後的中心給取出來
labels=np.asarray(kmeans.labels_,dtype=np.uint8)
labels=labels.reshape(rows,cols)#變成二維的了,所以是灰度圖形式
print(clusters.shape) #(128, 3)
np.save('codebook_test.npy',clusters)
io.imsave('compressed_test.jpg',labels)
3. 儲存與顯示
image=io.imread('compressed_test.jpg')
io.imshow(image)
io.show()
太嚇人了!!!