1. 程式人生 > >決策樹(Decision Tree) | 繪製決策樹

決策樹(Decision Tree) | 繪製決策樹

01 起

這篇文章中,我們講解了如何訓練決策樹,然後我們得到了一個字典巢狀格式的決策樹結果,這個結果不太直觀,不能一眼看著這顆“樹”的形狀、分支、屬性值等,怎麼辦呢?

本文就上文得到的決策樹,給出決策樹繪製函式,讓我們對我們訓練出的決策樹一目瞭然。

在繪製決策樹之後,我們會給出決策樹的使用方法:如何利用訓練好的決策樹,預測訓練資料的類別?

提示:不論是繪製還是使用決策樹,中心思想都是遞迴

02 決策樹寬度和深度

在繪製決策樹之前,我們需要知道利用python繪圖的部分知識,比如如何在圖中添加註解?

添加註解 annotate()

import matplotlib.pyplot as plt
%matplotlib inline 
%config InlineBackend.figure_format="retina" 
#設定出圖顯示中文
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus'] = False


decisionNode = dict(boxstyle="sawtooth", fc="0.8")
#decisionNode={boxstyle:"sawtooth",fc:"0.8"} #決策節點樣式
leafNode=dict(boxstyle="round4",fc="0.8") #葉節點樣式

#設定標註箭頭格式,->表示由註解指向外,<-表示指向註解,<->表示雙向箭頭
arrow_args=dict(arrowstyle="->") 
#arrow_args=dict(facecolor="blue",shrink=0.05) #另一種設定箭頭格式的方式

def plotNode(nodeText,centerPt,parentPt,nodeType):
    # nodeTxt為要顯示的文字,centerPt為文字的中心點,parentPt為指向文字的點 
    createPlot.ax1.annotate(nodeText,xytext=centerPt,textcoords="axes fraction",\
                            xy=parentPt,xycoords="axes fraction",\
                           va="bottom",ha="center",bbox=nodeType,arrowprops=arrow_args)

def createPlot():
    fig=plt.figure(figsize=(6,6),facecolor="white")
    fig.clf() #清空畫布
    
    # createPlot.ax1為全域性變數,繪製圖像的控制代碼,subplot為定義了一個繪圖
    #111表示figure中的圖有1行1列,即1個,最後的1代表第一個圖 
    # frameon表示是否繪製座標軸矩形 
    createPlot.ax1=plt.subplot(111,frameon=False)
    
    plotNode("決策節點",(0.8,0.4),(1.1,0.8),decisionNode)
    plotNode("葉節點",(0.5,0.2),(0.2,0.5),leafNode)
    plt.show()

執行createPlot()之後,得到這張圖,註解就新增好了,之後我們會利用這個方法新增決策樹的註解:

我們還需要知道如何計算一棵決策樹的寬度和深度?

計算決策樹寬度和深度 寬度:決策樹的葉節點個數 深度:決策樹最長分支的節點數

"""
輸入:字典巢狀格式的決策樹
輸出:該決策樹的葉節點數,相當於決策樹寬度(W)
"""
def countLeaf(desicionTree):
    cntLeaf=0
    firstFeatrue=list(desicionTree.keys())[0] #決策樹字典的第一個key是第一個最優特徵,為什麼要提取這個特徵呢,因為後面要遍歷該特徵的屬性值從而找個子樹
    subTree=desicionTree[firstFeatrue] #取節點key的value,即子樹
    
    for key in list(subTree.keys()): #遍歷最優特徵的各屬性值,每個屬性對應一個子樹,判斷子樹是否為葉節點
        if type(subTree[key]).__name__=="dict": 
            #如果當前屬性值對應的子樹型別為字典,說明這個節點不是葉節點,
            #那麼就遞迴呼叫自己,層層下探找到該通路葉節點,然後向上求和得到該通路葉節點數
            cntLeaf += countLeaf(subTree[key]) #遞迴
        else:
            cntLeaf += 1
    return cntLeaf

"""
輸入:字典巢狀格式的決策樹
輸出:該決策樹的深度(D)
"""
def countDepth(desicionTree):
    maxDepth=0
    firstFeatrue=list(desicionTree.keys())[0] #當前樹的最優特徵
    subTree=desicionTree[firstFeatrue]
    
    for key in list(subTree.keys()): #遍歷最優特徵的各屬性值,每個屬性對應一個子樹,判斷子樹是否為葉節點
        if type(subTree[key]).__name__=="dict": 
            thisDepth = 1+countDepth(subTree[key]) #這裡值得認真思考過程,作圖輔助思考
        else:
            thisDepth=1
        if thisDepth>maxDepth:
            maxDepth=thisDepth

我們拿上文訓練好的決策樹來測試一下,決策樹長這樣:

測試:

我們訓練好的決策樹寬度為8,深度為4

好了,下面我們可以進入繪製主函數了!

03 繪製決策樹

目前我們已經得到了決策樹的寬度和深度,還知道了如何在圖中添加註解,下面我們開始繪製決策樹,中心思想還是遞迴。

#自定義函式,在父子節點之間新增文字資訊,在決策樹中,相當於標註父結點特徵的屬性值
#cntPt是子節點座標,parentPt是父節點座標
def plotMidText(cntrPt,parentPt,nodeText):
    xMid=(parentPt[0]-cntrPt[0])/2+cntrPt[0]
    yMid=(parentPt[1]-cntrPt[1])/2+cntrPt[1]
    createPlot.ax1.text(xMid,yMid,nodeText)
    

#自定義函式,是繪製決策樹的主力軍
def plotTree(decisionTree,parentPt,nodeText):
    cntLeafs=countLeaf(decisionTree)
    depth=countDepth(decisionTree)
    feature=list(decisionTree.keys())[0] #提取當前樹的第一個特徵
    subDict=decisionTree[feature] #提取該特徵的子集,該子集可能是一個新的字典,那麼就繼續遞迴呼叫子集繪製圖,否則該特徵對應的子集為葉節點
    
    #繪製特徵以及該特徵屬性
    cntrPt=(plotTree.xOff+(1.0+float(cntLeafs))/2.0/plotTree.totalW,plotTree.yOff) #根據整棵樹的寬度深度計算當前子節點的繪製座標
    plotMidText(cntrPt,parentPt,nodeText) #繪製屬性
    plotNode(feature,cntrPt,parentPt,decisionNode) #繪製特徵
    
    #第一個特徵繪製好之後,第二個特徵的y座標向下遞減(因為自頂向下繪製,yOff初始值為1.0,然後y遞減)
    plotTree.yOff=plotTree.yOff-1.0/plotTree.totalD
    
    #遍歷當前樹的第一個特徵的各屬性值,判斷各屬性值對應的子資料集是否為葉節點,是則繪製葉節點,否則遞迴呼叫plotTree(),直到找到葉節點
    for key in subDict.keys(): 
        if type(subDict[key]).__name__=="dict":
            plotTree(subDict[key],cntrPt,str(key))
        else:
            plotTree.xOff=plotTree.xOff+1.0/plotTree.totalW #從左至右繪製,x初始值較小,然後x遞增
            plotNode(subDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
            plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
    #在上述遞迴呼叫plotTree()的過程中,yOff會不斷被減小
    #當我們遍歷完該特徵的某屬性值(即找到該屬性分支的葉節點),開始對該特徵下一屬性值判斷時,若無下面語句,則該屬性對應的節點會從上一屬性最小的yOff開始合理
    #下面這行程式碼,作用是:在找到葉節點結束遞迴時,對yOff加值,保證下一次判斷時的y起點與本次初始y一致
    #若不理解,可以嘗試註釋掉下面這行語句,看看效果
    plotTree.yOff=plotTree.yOff+1.0/plotTree.totalD 

    
#繪圖主函式
def createPlot(decisionTree):
    fig=plt.figure(figsize=(10,10),facecolor="white")
    fig.clf() #清空畫布
    axprops=dict(xticks=[],yticks=[]) #設定xy座標軸的刻度,在[]中填充座標軸刻度值,[]表示無刻度
    # createPlot.ax1為全域性變數,繪製圖像的控制代碼,subplot為定義了一個繪圖
    #111表示figure中的圖有1行1列,即1個,最後的1代表第一個圖 
    # frameon表示是否繪製座標軸矩形 
    createPlot.ax1=plt.subplot(111,frameon=False,**axprops)
    
    plotTree.totalW=float(countLeaf(decisionTree)) #全域性變數,整棵決策樹的寬度
    plotTree.totalD=float(countDepth(decisionTree))#全域性變數,整棵決策樹的深度
    plotTree.xOff=-0.5/plotTree.totalW
    plotTree.yOff=1.0
    
    plotTree(decisionTree,(0.5,1.0),'')
    plt.show()

下面我們用訓練好的決策樹測試一下繪製函式,激動人心的時刻到了

createPlot(my_tree)

這就是我們訓練的決策樹,每個節點代表一個特徵,節點連線的箭頭屬性代表該特徵的屬性值,比如特徵(紋理)=屬性值(清晰)

04 使用決策樹執行分類

目前,我們能夠訓練決策樹,能夠繪製決策樹了,但是決策樹主要的作用還沒有發揮出來。

分類決策樹,作用在於利用訓練好的決策樹,對測試集資料進行分類。

下面我們就展示如何利用我們訓練好的決策樹對測試集進行分類。

中心思想:比較某條測試資料與決策樹的數值,遞迴執行,直到找到某條測試資料的葉節點,然後該測試資料被分類為該葉節點分類

"""
輸入:訓練好的分類決策樹、該決策樹的特徵列表、某條測試資料各特徵屬性值(順序與決策樹特徵列表一致)
輸出:該條測試資料的分類
思路:
比較某條測試資料與決策樹的數值,遞迴執行,
直到找到某條測試資料的葉節點,然後該測試資料被分類為該葉節點分類
"""
def classifyDT(decisionTree,treeFeatures,testVec):
    firstFeature=list(decisionTree.keys())[0]
    subDict=decisionTree[firstFeature]
    
    #尋找當前樹最優特徵在特徵列表中的位置,便於定位測試資料集對於的特徵位置
    featureIndex=treeFeatures.index(firstFeature) 
    
    """判斷邏輯
    遍歷當前樹最優特徵各屬性值
    若測試資料對應位置的特徵值與key一致,就在這個分支上找下去
    若此特徵屬性值對應的分支不是葉節點,就遞迴呼叫自己,繼續在此分支上下探尋找葉節點
    若此特徵屬性值對應的分支是葉節點,就把該測試資料分類到該葉節點類別
    """
    for key in list(subDict.keys()): #遍歷當前樹最優特徵各屬性值
        if testVec[featureIndex]==key: #若測試資料對應位置的特徵值與key一致,就在這個分支上找下去
            if type(subDict[key]).__name__=="dict": #若此特徵屬性值對應的分支不是葉節點,就遞迴呼叫自己,繼續在此分支上下探尋找葉節點
                classLabel=classifyDT(subDict[key],treeFeatures,testVec)
            else:
                classLabel=subDict[key] #若此特徵屬性值對應的分支是葉節點,就把該測試資料分類到該葉節點類別
    return classLabel

我們來測試一下,classifyDT(),第一個引數代表訓練好的決策樹,第二個引數代表決策樹對應的特徵列表,第三個引數就是訓練集資料的特徵屬性了,這些屬性要對應第二個引數的特徵順序。

classifyDT(my_tree,\
           ['色澤', '根蒂', '敲聲', '紋理', '臍部', '觸感'],\
           ['青綠', '蜷縮', '濁響', '清晰', '凹陷', '軟粘'])

預測結果:

我們訓練的決策樹告訴我們,這顆待預測的西瓜,是一顆好瓜!

nice!

05 總結

本文給出了決策樹繪製方法和決策樹使用方法,中心思想都是遞迴。

本文訓練決策樹使用的演算法是ID3,資訊增益,這種演算法只能處理離散型資料,且只能用於分類。如果之後有精力,我們會給出另一種決策樹訓練演算法—CART演算法,這種方法可以處理連續型資料,且還可以用於迴歸。

敬請期待~~

06 參考

  1. 《統計學習方法》 李航 Chapter5
  2. 《機器學習實戰》 Peter Harrington Chapter3