決策樹分類鳶尾花資料集
阿新 • • 發佈:2018-12-08
# Train a decision tree on two iris features (sepal length and petal length)
# and draw its decision regions together with the training samples.
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier

# Column labels: four measurements followed by the class label.
iris_feature = u'花萼長度', u'花萼寬度', u'花瓣長度', u'花瓣寬度', u'類別'

path = '8.iris.data'  # path to the data file
data = pd.read_csv(path, header=None)
# Label the columns; with header=None they would otherwise be the
# integers 0..n-1.
data.columns = iris_feature
# Replace each distinct class label with its integer category code.
data['類別'] = pd.Categorical(data['類別']).codes

# Only two features are used so the decision regions can be drawn in 2-D.
x_train = data[['花萼長度', '花瓣長度']]
y_train = data['類別']

model = DecisionTreeClassifier(criterion='entropy', min_samples_leaf=3)
model.fit(x_train, y_train)

# Number of sample points along the horizontal and vertical axes.
N, M = 500, 500
x1_min, x2_min = x_train.min(axis=0)
x1_max, x2_max = x_train.max(axis=0)
t1 = np.linspace(x1_min, x1_max, N)
t2 = np.linspace(x2_min, x2_max, M)
x1, x2 = np.meshgrid(t1, t2)  # grid of sample points

# Grid points to classify. They are wrapped in a DataFrame carrying the
# training column names because the model was fitted on a DataFrame;
# passing a bare ndarray makes recent scikit-learn versions warn that X
# has no valid feature names.
x_show = pd.DataFrame(
    np.stack((x1.ravel(), x2.ravel()), axis=1),
    columns=x_train.columns,
)
y_predict = model.predict(x_show)

# Use a font that contains CJK glyphs and keep the minus sign renderable.
mpl.rcParams['font.sans-serif'] = ['SimHei']
mpl.rcParams['axes.unicode_minus'] = False

# Light colours for the predicted regions, dark colours for the samples.
cm_light = mpl.colors.ListedColormap(['#A0FFA0', '#FFA0A0', '#A0A0FF'])
cm_dark = mpl.colors.ListedColormap(['g', 'r', 'b'])

plt.xlim(x1_min, x1_max)
plt.ylim(x2_min, x2_max)
plt.pcolormesh(x1, x2, y_predict.reshape(x1.shape), cmap=cm_light)
plt.scatter(
    x_train['花萼長度'],
    x_train['花瓣長度'],
    c=y_train,
    cmap=cm_dark,
    marker='o',
    edgecolors='k',
)
plt.xlabel('花萼長度')
plt.ylabel('花瓣長度')
plt.title('鳶尾花分類')
plt.grid(True, ls=':')
# Save before show(): the figure may be released once the window closes.
plt.savefig('1.png')
plt.show()