視覺化決策樹之Python實現
阿新 • • 發佈:2019-02-02
決策樹(Decision Tree)是在已知各種情況發生概率的基礎上,通過構成決策樹來求取淨現值的期望值大於等於零的概率,評價專案風險,判斷其可行性的決策分析方法,是直觀運用概率分析的一種圖解法。一些基礎原理這裡就不再一一介紹了,直接進入今天的主題,如何視覺化決策樹。
本篇使用klearn來實現決策樹的過程,下面是詳細講解:
首先匯入必要的包:
然後,匯入資料集。我用的是kaggle上的蘑菇資料集,這是一個經典的決策樹資料集,非常適合決策樹,下面我們就會知道。import pandas as pd import numpy as np import matplotlib.pyplot as plt from sklearn.model_selection import GridSearchCV from sklearn.metrics import accuracy_score from sklearn.metrics import roc_auc_score
data = pd.read_csv("mushrooms.csv")
data.head()
先初步認識一下資料集:
可以看出這是一個分類變數的資料集。然後,我們就要將它變成數值變數,好利於下面的建模。
from sklearn.preprocessing import LabelEncoder labelencoder = LabelEncoder() for col in data.columns: data[col] = labelencoder.fit_transform(data[col]) data.head()
之後,我們來看看資料的大小:
data.shape
(8124, 23)資料準備後,我們開始提取訓練集與測試集。
y = data['class']
X = data.drop('class', axis=1)
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0, train_size=0.8)
columns = X_train.columns
接著標準化訓練集
接著,構建決策樹模型# 資料標準化 from sklearn.preprocessing import StandardScaler ss_X = StandardScaler() ss_y = StandardScaler() X_train = ss_X.fit_transform(X_train) X_test = ss_X.transform(X_test)
from sklearn.tree import DecisionTreeClassifier
model_tree = DecisionTreeClassifier()
model_tree.fit(X_train, y_train)
評價模型準確性
y_prob = model_tree.predict_proba(X_test)[:,1]
y_pred = np.where(y_prob > 0.5, 1, 0)
model_tree.score(X_test, y_pred)
可以得到結果:1.
說明決策樹非常吻合此資料集。
最後,完成決策樹的視覺化
# 視覺化樹圖
data_ = pd.read_csv("mushrooms.csv")
data_feature_name = data_.columns[1:]
data_target_name = np.unique(data_["class"])
import graphviz
import pydotplus
from sklearn import tree
from IPython.display import Image
import os
os.environ["PATH"] += os.pathsep + 'C:/Program Files (x86)/Graphviz2.38/bin/'
dot_tree = tree.export_graphviz(model_tree,out_file=None,feature_names=data_feature_name,class_names=data_target_name,filled=True, rounded=True,special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_tree)
img = Image(graph.create_png())
graph.write_png("out.png")
注意:graphviz包不僅需要使用pip install graphviz安裝還需要單獨安裝。使用時,還需要引入graphviz絕對路徑。
參考:http://scikit-learn.org/stable/modules/tree.html
graphviz-2.38.msi安裝包下載:http://www.graphviz.org/Download_windows.php
資料集:http://download.csdn.net/download/llh_1178/10115766