西瓜書 課後習題3.5 線性判別分析
阿新 • 發佈：2018-11-27
import csv

import numpy as np


def readData(filename, encoding=None):
    """Read the watermelon 3.0 dataset from a CSV file and split it by class.

    The first row is treated as a header and skipped. For every other row,
    column 7 (density) and column 8 (sugar content) are the features,
    column 9 is the class label ('是' = positive, '否' = negative) and
    column 10 is parsed as a float target value. Rows whose label is neither
    of those two values are ignored.

    :param filename: path to the CSV dataset.
    :param encoding: text encoding passed to ``open``. ``None`` (the default)
        keeps the previous behaviour of using the platform's locale encoding;
        pass e.g. ``'utf-8'`` or ``'gbk'`` when the file was saved with a
        different encoding, otherwise the label comparisons cannot match.
    :return: tuple ``(X1, X2, y1, y2, density, sugar)``:
        X1 / X2 -- lists of ``[density, sugar]`` pairs of the positive /
        negative class; y1 / y2 -- lists of one-element ``[target]`` lists
        for the same rows; density / sugar -- flat lists of the feature values
        of all kept rows, in file order. An empty file yields six empty lists.
    """
    X1, X2, y1, y2 = [], [], [], []
    density, sugar = [], []
    # The csv module documentation requires newline='' for files given to csv.reader.
    with open(filename, newline='', encoding=encoding) as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row (tolerates an empty file)
        for line in reader:
            label = line[9].strip()
            if label == '是':
                X, y = X1, y1
            elif label == '否':
                X, y = X2, y2
            else:
                continue  # unknown label: row is not part of either class
            d, s = float(line[7]), float(line[8])
            X.append([d, s])
            y.append([float(line[10])])
            density.append(d)
            sugar.append(s)
    return X1, X2, y1, y2, density, sugar


def LDA(X1, X2):
    """Two-class Fisher linear discriminant analysis.

    Computes ``omega = Sw^{-1} (mu1 - mu2)``, where ``mu_i`` is the mean of
    class i and ``Sw = sum_i (X_i - mu_i)^T (X_i - mu_i)`` is the
    within-class scatter matrix.

    :param X1: array-like of shape [N1, d], samples of the first class.
    :param X2: array-like of shape [N2, d], samples of the second class.
    :return: omega, np.ndarray of shape [d, 1], the LDA projection direction.
    :raises numpy.linalg.LinAlgError: if Sw is singular.
    """
    X1 = np.asarray(X1, dtype=float)
    X2 = np.asarray(X2, dtype=float)
    mean1 = np.mean(X1, axis=0, keepdims=True)  # shape [1, d]
    mean2 = np.mean(X2, axis=0, keepdims=True)
    centered1 = X1 - mean1
    centered2 = X2 - mean2
    Sw = centered1.T.dot(centered1) + centered2.T.dot(centered2)  # shape [d, d]
    # Solving Sw @ omega = (mean1 - mean2)^T avoids forming inv(Sw) explicitly,
    # which is numerically preferable; the result is the same for non-singular Sw.
    return np.linalg.solve(Sw, (mean1 - mean2).T)  # shape [d, 1]


if __name__ == '__main__':
    # Imported here so readData/LDA can be imported without matplotlib installed.
    import matplotlib.pyplot as plt

    dataset = "C:\\Users\\14399\\Desktop\\西瓜3.0.csv"
    X1, X2 = readData(dataset)[:2]
    X1 = np.array(X1)
    X2 = np.array(X2)

    # Visualization: plot each class from its own samples instead of slicing the
    # combined density/sugar lists at index 8, which silently mislabels points
    # unless every positive row precedes every negative row in the file.
    plt.plot(X1[:, 0], X1[:, 1], 'r+')
    plt.plot(X2[:, 0], X2[:, 1], 'bo')

    omega = LDA(X1, X2)

    # Draw the line omega[0]*x + omega[1]*y = 0 for x in [0, 0.9]; it passes
    # through the origin and is perpendicular to omega.
    lda_left = 0
    lda_right = -(omega[0, 0] * 0.9) / omega[1, 0]
    plt.plot([0, 0.9], [lda_left, lda_right], 'g')
    plt.xlabel('density')
    plt.ylabel('sugar')
    plt.title('LDA')
    plt.show()
結果:
西瓜3.0資料集:連結:https://pan.baidu.com/s/1RXTUG9gP1Jn9HKFCiEzOlA 密碼:3h6n
參考資料:https://blog.csdn.net/victoriaw/article/details/77989610