1. 程式人生 > 其它 >程式設計實現線性判別分析,並給出西瓜資料集3.0α上的執行結果

程式設計實現線性判別分析,並給出西瓜資料集3.0α上的執行結果

線性判別分析

1.題目理解

將西瓜資料集的樣例投影到一條直線上,使得好瓜、壞瓜各自的投影點儘可能接近,好瓜與壞瓜之間的投影點儘可能遠離。

2.演算法原理

3.演算法設計

① 根據LDA原理求解得到w,結合資料集得到LDA直線;

② 將每個樣本對映到LDA直線上,觀察分析結果。

4.關鍵程式碼

 1 # 載入資料集
 2 dataset = np.loadtxt('C:/Users/86185/PycharmProjects/ML1/watermelon_3a.csv', delimiter=",")
 3 
 4 # 分離屬性值和標籤
 5 X = dataset[:,1:3]
 6 y = dataset[:,3]
7 u = [] 8 for i in range(2): 9 u.append(np.mean(X[y==i],axis=0)) 10 11 m,n = np.shape(X) 12 Sw = np.zeros((n,n)) 13 for i in range(m): 14 x_temp = X[i].reshape(n, 1) # 行向量變為列向量 15 if y[i]==0: u_temp = u[0].reshape(n, 1) 16 if y[i]==1: u_temp = u[1].reshape(n, 1) 17 Sw +=np.dot(x_temp-u_temp, (x_temp-u_temp).T)
18 19 Sw = np.mat(Sw) 20 # print(Sw) 21 Sw_inv = np.linalg.inv (Sw) 22 # print(Sw_inv) 23 w = np.dot(Sw_inv, (u[0]-u[1]).reshape(n,1)) 24 print(w)

先根據公式求得w

 1 def GetPoint(point0, w):
 2     k0 = w[1, 0]/w[0, 0]
 3     k1 = w[0, 0]/w[1, 0]
 4     x0 = point0[0]
 5     y0 = point0[1]
 6     x1 = (k0 * x0 - y0) / (k0 + k1)
7 y1 = -k1 * x1 8 return x0, x1, y0, y1 9 10 f1 = plt.figure('first') 11 plt.xlim( -0.2, 1 ) # 設定座標軸的範圍 12 plt.ylim( -0.2, 0.6 ) 13 14 15 x = np.arange(-1, 3) 16 yy = -(w[0,0]/w[1,0])*x

做LDA直線yy;GetPoint()函式用來計算點到直線yy的投影

5.結果展示

根據執行結果顯示,沒有很明確的將好瓜與壞瓜區分開來,好瓜與壞瓜的投影點不夠遠離,壞瓜與壞瓜之間的投影點不夠聚集。