程式設計實現線性判別分析,並給出西瓜資料集3.0α上的執行結果
阿新 • • 發佈:2022-03-03
線性判別分析
1.題目理解
將西瓜資料集的樣例投影到一條直線上,使得好瓜、壞瓜各自的投影點儘可能接近,好瓜與壞瓜之間的投影點儘可能遠離。
2.演算法原理
3.演算法設計
① 根據LDA原理求解得到w,結合資料集得到LDA直線;
② 將每個樣本對映到LDA直線上,觀察分析結果。
4.關鍵程式碼
1 # 載入資料集 2 dataset = np.loadtxt('C:/Users/86185/PycharmProjects/ML1/watermelon_3a.csv', delimiter=",") 3 4 # 分離屬性值和標籤 5 X = dataset[:,1:3] 6 y = dataset[:,3]7 u = [] 8 for i in range(2): 9 u.append(np.mean(X[y==i],axis=0)) 10 11 m,n = np.shape(X) 12 Sw = np.zeros((n,n)) 13 for i in range(m): 14 x_temp = X[i].reshape(n, 1) # 行向量變為列向量 15 if y[i]==0: u_temp = u[0].reshape(n, 1) 16 if y[i]==1: u_temp = u[1].reshape(n, 1) 17 Sw +=np.dot(x_temp-u_temp, (x_temp-u_temp).T)18 19 Sw = np.mat(Sw) 20 # print(Sw) 21 Sw_inv = np.linalg.inv (Sw) 22 # print(Sw_inv) 23 w = np.dot(Sw_inv, (u[0]-u[1]).reshape(n,1)) 24 print(w)
先根據公式求得w
1 def GetPoint(point0, w): 2 k0 = w[1, 0]/w[0, 0] 3 k1 = w[0, 0]/w[1, 0] 4 x0 = point0[0] 5 y0 = point0[1] 6 x1 = (k0 * x0 - y0) / (k0 + k1)7 y1 = -k1 * x1 8 return x0, x1, y0, y1 9 10 f1 = plt.figure('first') 11 plt.xlim( -0.2, 1 ) # 設定座標軸的範圍 12 plt.ylim( -0.2, 0.6 ) 13 14 15 x = np.arange(-1, 3) 16 yy = -(w[0,0]/w[1,0])*x
做LDA直線yy;GetPoint()函式用來計算點到直線yy的投影
5.結果展示
根據執行結果顯示,沒有很明確的將好瓜與壞瓜區分開來,好瓜與壞瓜的投影點不夠遠離,壞瓜與壞瓜之間的投影點不夠聚集。