1. 程式人生 > >視覺化決策樹之Python實現

視覺化決策樹之Python實現

決策樹(Decision Tree)是在已知各種情況發生概率的基礎上,通過構成決策樹來求取淨現值的期望值大於等於零的概率,評價專案風險,判斷其可行性的決策分析方法,是直觀運用概率分析的一種圖解法。一些基礎原理這裡就不再一一介紹了,直接進入今天的主題,如何視覺化決策樹。

本篇使用klearn來實現決策樹的過程,下面是詳細講解:

首先匯入必要的包:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_auc_score
然後,匯入資料集。我用的是kaggle上的蘑菇資料集,這是一個經典的決策樹資料集,非常適合決策樹,下面我們就會知道。
data = pd.read_csv("mushrooms.csv")
data.head()
先初步認識一下資料集:


可以看出這是一個分類變數的資料集。然後,我們就要將它變成數值變數,好利於下面的建模。

from sklearn.preprocessing import LabelEncoder
labelencoder = LabelEncoder()
for col in data.columns:
    data[col] = labelencoder.fit_transform(data[col])
data.head()

之後,我們來看看資料的大小:
data.shape
(8124, 23)
資料準備後,我們開始提取訓練集與測試集。
y = data['class']
X = data.drop('class', axis=1)
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0, train_size=0.8)
columns = X_train.columns
接著標準化訓練集
# 資料標準化
from sklearn.preprocessing import StandardScaler
ss_X = StandardScaler()
ss_y = StandardScaler()
X_train = ss_X.fit_transform(X_train)
X_test = ss_X.transform(X_test)
接著,構建決策樹模型
from sklearn.tree import DecisionTreeClassifier
model_tree = DecisionTreeClassifier()
model_tree.fit(X_train, y_train)
評價模型準確性
y_prob = model_tree.predict_proba(X_test)[:,1]
y_pred = np.where(y_prob > 0.5, 1, 0)
model_tree.score(X_test, y_pred)
可以得到結果:1. 

說明決策樹非常吻合此資料集。

最後,完成決策樹的視覺化

# 視覺化樹圖
data_ = pd.read_csv("mushrooms.csv")
data_feature_name = data_.columns[1:]
data_target_name = np.unique(data_["class"])
import graphviz
import pydotplus
from sklearn import tree
from IPython.display import Image
import os
os.environ["PATH"] += os.pathsep + 'C:/Program Files (x86)/Graphviz2.38/bin/'
dot_tree = tree.export_graphviz(model_tree,out_file=None,feature_names=data_feature_name,class_names=data_target_name,filled=True, rounded=True,special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_tree)
img = Image(graph.create_png())
graph.write_png("out.png")




注意:graphviz包不僅需要使用pip install graphviz安裝還需要單獨安裝。使用時,還需要引入graphviz絕對路徑。

參考:http://scikit-learn.org/stable/modules/tree.html

graphviz-2.38.msi安裝包下載:http://www.graphviz.org/Download_windows.php

資料集:http://download.csdn.net/download/llh_1178/10115766