1. 程式人生 > 決策樹例項

決策樹例項

import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report


def load_data():
    """Load the iris dataset, split it 70/30 and standardize the features.

    The scaler is fitted on the training split only and then applied to both
    splits, so no statistics of the held-out test samples leak into the
    preprocessing step.

    Returns:
        tuple: ``(x_train, y_train, x_test, y_test, feature_names)`` where the
        ``x_*`` values are the standardized feature matrices, the ``y_*``
        values are the class labels and ``feature_names`` is taken from the
        dataset object returned by ``load_iris()``.
    """
    from sklearn.datasets import load_iris  # iris flower dataset
    from sklearn.preprocessing import StandardScaler  # feature scaling
    from sklearn.model_selection import train_test_split  # train/test split

    data = load_iris()
    # Split BEFORE scaling; the fixed seed keeps the split reproducible.
    x_train, x_test, y_train, y_test = train_test_split(
        data.data, data.target, test_size=0.3, random_state=1
    )
    # Learn mean/variance from the training rows only, then reuse them for
    # the test rows so both splits go through the same transformation.
    scaler = StandardScaler()
    x_train = scaler.fit_transform(x_train)
    x_test = scaler.transform(x_test)
    return x_train, y_train, x_test, y_test, data.feature_names


def train():
    """Fit a default decision tree on the training split and print test metrics.

    Prints the value of ``score`` on the test split, followed by the
    classification report of the test-set predictions.
    """
    features_fit, labels_fit, features_eval, labels_eval, _ = load_data()
    clf = DecisionTreeClassifier().fit(features_fit, labels_fit)
    predicted = clf.predict(features_eval)
    test_score = clf.score(features_eval, labels_eval)
    print(test_score)
    report = classification_report(labels_eval, predicted)
    print(report)


def grid_search():
    """Grid-search the ``max_depth`` of a decision tree and report the outcome.

    Runs ``GridSearchCV`` with ``cv=5`` over the candidate depths, prints the
    best parameter setting together with its score, then prints the
    classification report of the fitted search object on the test split.
    """
    from sklearn.model_selection import GridSearchCV  # grid search

    # The fifth value returned by load_data() (the feature names) is not
    # needed here, hence the throwaway name.
    features_fit, labels_fit, features_eval, labels_eval, _ = load_data()

    # The hyper-parameter being tuned is the tree depth: 1, 3, ..., 49.
    depth_grid = {'max_depth': np.arange(1, 50, 2)}
    search = GridSearchCV(DecisionTreeClassifier(), depth_grid, verbose=5, cv=5)
    search.fit(features_fit, labels_fit)
    print('最佳模型:', search.best_params_, search.best_score_)

    predicted = search.predict(features_eval)
    print(classification_report(labels_eval, predicted))


def tree_visilize():
    """Fit a ``max_depth=3`` decision tree and export it as a Graphviz DOT file.

    Prints the distinct labels of the training split, the feature names and
    the test-split score, then writes the tree to ``allElectronicsData.dot``
    (relative path, opened in text write mode).
    """
    from sklearn import tree

    features_fit, labels_fit, features_eval, labels_eval, feature_names = load_data()
    # np.unique drops repeated labels, leaving the distinct class values.
    distinct_labels = np.unique(labels_fit)
    print('類標:', distinct_labels)
    print('特徵名稱:', feature_names)

    shallow_tree = DecisionTreeClassifier(max_depth=3)
    shallow_tree.fit(features_fit, labels_fit)
    print(shallow_tree.score(features_eval, labels_eval))

    export_options = {
        'feature_names': feature_names,
        'class_names': ['A', 'B', 'C'],
    }
    with open("allElectronicsData.dot", "w") as dot_file:
        tree.export_graphviz(shallow_tree, out_file=dot_file, **export_options)


if __name__ == '__main__':
    # Run the three demos in order:
    #   train         - fit a decision tree without pruning; this yields the
    #                   largest, most complex tree and overfits easily.
    #   grid_search   - fit decision trees and pick the most suitable
    #                   hyper-parameter (the depth used to prune the tree).
    #   tree_visilize - export the pruned decision tree (depth 3) for display.
    for step in (train, grid_search, tree_visilize):
        step()

其中,tree_visilize 函式用來將決策樹匯出成 .dot 檔案以便顯示;要開啟該檔案,需要先下載並安裝 Graphviz(安裝檔為 graphviz-2.38.msi)。安裝步驟請參考:詳細操作

決策樹的顯示:

這個檔案(allElectronicsData.dot)是由上面的程式碼生成的。

要顯示所生成的決策樹,可以用安裝好的 gvedit.exe 開啟該檔案;gvedit.exe 在這裡:

顯示:

決策樹詳細講解文章:

文章1

文章2

文章3