決策樹例項
阿新 • 發佈:2018-11-28
import numpy as np
from sklearn import tree
from sklearn.datasets import load_iris  # Iris dataset
from sklearn.metrics import classification_report
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.preprocessing import StandardScaler  # feature scaling
from sklearn.tree import DecisionTreeClassifier

# Single seed shared by the train/test split and every tree, so results are
# reproducible from run to run (the split used random_state=1 originally).
RANDOM_STATE = 1


def load_data():
    """Load Iris, hold out 30% as a test split, and standardize the features.

    The split is done BEFORE scaling: the scaler is fitted on the training
    split only and then applied to the test split, so no statistics from the
    test data leak into training.

    Returns:
        tuple: ``(x_train, y_train, x_test, y_test, feature_names)``.
    """
    data = load_iris()
    x_train, x_test, y_train, y_test = train_test_split(
        data.data, data.target, test_size=0.3, random_state=RANDOM_STATE
    )
    scaler = StandardScaler()
    x_train = scaler.fit_transform(x_train)
    x_test = scaler.transform(x_test)  # reuse training mean/std; do not refit
    return x_train, y_train, x_test, y_test, data.feature_names


def train():
    """Fit a decision tree with no depth limit and print test-set metrics.

    No pruning is applied (``max_depth`` is left at its default), so the tree
    grows as large as the training data allows and is prone to overfitting.
    Prints the test accuracy followed by a per-class classification report.
    """
    x_train, y_train, x_test, y_test, _ = load_data()
    model = DecisionTreeClassifier(random_state=RANDOM_STATE)
    model.fit(x_train, y_train)
    y_pred = model.predict(x_test)
    print(model.score(x_test, y_test))
    print(classification_report(y_test, y_pred))


def grid_search():
    """Choose ``max_depth`` by 5-fold cross-validated grid search, then evaluate.

    Limiting the depth acts as pre-pruning; the odd depths 1, 3, ..., 49 are
    tried. Prints the best parameters with their mean CV score, then a
    classification report on the held-out test split.
    """
    # The fifth return value (feature names) is not needed here.
    x_train, y_train, x_test, y_test, _ = load_data()
    model = DecisionTreeClassifier(random_state=RANDOM_STATE)
    parameters = {'max_depth': np.arange(1, 50, 2)}  # hyperparameter: tree depth
    gs = GridSearchCV(model, parameters, verbose=5, cv=5)
    gs.fit(x_train, y_train)
    print('最佳模型:', gs.best_params_, gs.best_score_)
    # GridSearchCV refits the best estimator on the whole training split
    # (refit=True by default), so predict() uses that refitted model.
    y_pred = gs.predict(x_test)
    print(classification_report(y_test, y_pred))


def tree_visilize(out_path="allElectronicsData.dot"):
    """Fit a depth-3 tree, print its test accuracy, and export it for Graphviz.

    Args:
        out_path: Destination of the ``.dot`` file. It defaults to the file
            name the script has always written.

    The file can be opened with Graphviz's gvedit, or rendered with e.g.
    ``dot -Tpng allElectronicsData.dot -o tree.png``. Split thresholds in
    the exported tree are in standardized units, because ``load_data``
    scales the features.
    """
    x_train, y_train, x_test, y_test, feature_names = load_data()
    print('類標:', np.unique(y_train))  # distinct class labels in y_train
    print('特徵名稱:', feature_names)
    # max_depth=3 is presumably the depth picked from grid_search() output.
    model = DecisionTreeClassifier(max_depth=3, random_state=RANDOM_STATE)
    model.fit(x_train, y_train)
    print(model.score(x_test, y_test))
    # Real species names, listed in ascending label order (0, 1, 2), as
    # export_graphviz expects; previously these were placeholders A/B/C.
    class_names = [str(name) for name in load_iris().target_names]
    with open(out_path, "w", encoding="utf-8") as f:
        tree.export_graphviz(
            model,
            feature_names=feature_names,
            class_names=class_names,
            out_file=f,
        )


if __name__ == '__main__':
    train()          # unpruned tree: maximal complexity, prone to overfitting
    grid_search()    # select max_depth (pre-pruning) by cross-validation
    tree_visilize()  # export the pruned (depth-3) tree for visualization
其中,tree_visilize 函式用於輸出並顯示決策樹,需要先下載並安裝 Graphviz 軟體(安裝檔為 graphviz-2.38.msi),詳細操作如下。
決策樹的顯示:
allElectronicsData.dot 這個檔案是由上面的程式碼生成的。
要顯示所生成的決策樹,可以用 Graphviz 安裝後附帶的 gvedit.exe 開啟該檔案,gvedit.exe 的位置如下。也可以在命令列執行 dot -Tpng allElectronicsData.dot -o tree.png,將其轉成圖片。
顯示結果如下:
決策樹詳細講解文章: