
Machine Learning in Action: Decision Trees

This post records the principles and implementation of decision trees as presented in Machine Learning in Action and Statistical Learning Methods.

1. Decision Trees

Definition: a classification decision tree is a tree structure that describes how instances are classified. A decision tree consists of nodes and directed edges. Nodes come in two types: internal nodes, each representing a feature (attribute), and leaf nodes, each representing a class.
To classify with a decision tree, start at the root node and test one feature of the instance; according to the test result, dispatch the instance to the matching child node, where each child node corresponds to one value of that feature. Repeat this test-and-dispatch step recursively until a leaf node is reached, and assign the instance to the class of that leaf.
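This test-and-dispatch procedure maps naturally onto nested Python dictionaries, which is how the implementation below stores trees. As a minimal preview sketch (the tree literal is the one the code will eventually learn from the toy data set):

```python
# A decision tree as nested dicts: each internal node is a
# {feature_name: {feature_value: subtree_or_class_label}} mapping.
tree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

def predict(tree, feat_names, sample):
    """Walk from the root, testing one feature per level, until a leaf."""
    node = tree
    while isinstance(node, dict):
        feat = next(iter(node))          # feature tested at this node
        node = node[feat][sample[feat_names.index(feat)]]
    return node

print(predict(tree, ['no surfacing', 'flippers'], [1, 1]))  # -> 'yes'
```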
The general workflow for a decision tree:
(1) Collect data: any method will do.
(2) Prepare data: the tree-building algorithm only works on nominal data, so numeric data must be discretized.
(3) Analyze data: any method; once the tree is built, check that the plotted tree matches expectations.
(4) Train the algorithm: build the tree data structure.
(5) Test the algorithm: compute the error rate with the learned tree.
(6) Use the algorithm: this step applies to any supervised learning algorithm; decision trees additionally make the internal meaning of the data easier to understand.
The decision tree algorithms in common use today are ID3, the improved C4.5, and CART.

2. The ID3 Algorithm: Principle and Implementation

ID3 is a classification algorithm proposed by J. Ross Quinlan in 1975 at the University of Sydney. It is grounded in information theory, with "information entropy" at its core. ID3 computes the information gain of every feature, treats high information gain as the mark of a good feature, picks the feature with the highest information gain as the splitting criterion at each step, and repeats this process until it produces a decision tree that classifies the training examples perfectly.
(The figures showing the formulas, and the corresponding passage from Statistical Learning Methods, are omitted here. In short: the empirical entropy of a data set D is H(D) = -Σ_k p_k · log2(p_k), and the information gain of feature A is g(D, A) = H(D) - H(D|A).)
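As a quick sanity check of the entropy formula on the toy data set used below (two 'yes' labels, three 'no'):

```python
from math import log2

# H(D) = -sum_k p_k * log2(p_k) for a data set with 2 'yes' and 3 'no' labels
probs = [2 / 5, 3 / 5]
H = -sum(p * log2(p) for p in probs)
print('%.4f' % H)  # -> 0.9710
```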
Below is a concrete implementation in Python:
1. First, create a data set:

# -*- coding: utf-8 -*-
from math import log
import operator
import treePlotter

# Create the data set
def createDataSet():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']   # feature names (discrete values)
    return dataSet, labels

2. Compute the Shannon entropy:

# Compute the Shannon entropy
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:              # count how often each class label occurs
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)    # log base 2
    return shannonEnt

myDat, labels = createDataSet()
print(calcShannonEnt(myDat))
myDat[0][-1] = 'maybe'
print(myDat, labels)
print('Ent changed: ', calcShannonEnt(myDat))

Output:
As expected, after myDat[0][-1] is changed the entropy increases (from about 0.971 to about 1.371), since the data now contains three classes instead of two.

3. Split the data set

# Split the data set
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:                   # keep rows whose column `axis` equals `value`
            reducedFeatVec = featVec[:axis]          # the columns before `axis`
            reducedFeatVec.extend(featVec[axis+1:])  # plus the columns after `axis`; the tested column is dropped
            retDataSet.append(reducedFeatVec)
    return retDataSet

print('splitDataSet is :', splitDataSet(myDat, 1, 1))

Output:
With axis = 1 (the second column) as the criterion, rows whose value != 1 are discarded, the tested column is removed, and the remaining columns are reassembled.
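The same split can be written as a single list comprehension; this sketch is an equivalent rewrite, not code from the book:

```python
def split_dataset(data, axis, value):
    # Keep rows whose column `axis` equals `value`, dropping that column.
    return [row[:axis] + row[axis + 1:] for row in data if row[axis] == value]

data = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
print(split_dataset(data, 1, 1))
# -> [[1, 'yes'], [1, 'yes'], [0, 'no'], [0, 'no']]
```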

4. Choose the best feature to split on:

# Choose the best feature to split on
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1      # the last column holds the class labels
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):           # iterate over all the features
        featList = [example[i] for example in dataSet]   # all values of feature i
        uniqueVals = set(featList)                       # the unique values of this feature
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))   # fraction of rows in this subset
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy    # information gain = reduction in entropy
        if infoGain > bestInfoGain:            # keep the best gain seen so far
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature                     # index of the best feature

print('the best feature is:', chooseBestFeatureToSplit(myDat))

Output:
That is, the best feature is the one at index 0 ('no surfacing').
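Working the numbers by hand confirms this; the sketch below recomputes the information gain of both features on the original toy data set:

```python
from math import log2

def entropy(labels):
    n = len(labels)
    return -sum((labels.count(c) / n) * log2(labels.count(c) / n)
                for c in set(labels))

data = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
base = entropy([row[-1] for row in data])      # ~0.971

gains = []
for i in range(2):                             # two candidate features
    cond = 0.0
    for v in set(row[i] for row in data):      # conditional entropy H(D|A_i)
        subset = [row[-1] for row in data if row[i] == v]
        cond += len(subset) / len(data) * entropy(subset)
    gains.append(base - cond)

print(['%.3f' % g for g in gains])  # -> ['0.420', '0.171'], so feature 0 wins
```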

5. Find the class name that occurs most often

# Return the class name that occurs most often
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    # operator.itemgetter(1) (imported at the top) sorts by the count, in descending order
    return sortedClassCount[0][0]
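For reference, the standard library offers the same logic in one line via collections.Counter; an equivalent sketch:

```python
from collections import Counter

def majority_cnt(class_list):
    # most_common(1) returns [(label, count)] for the most frequent label
    return Counter(class_list).most_common(1)[0][0]

print(majority_cnt(['no', 'yes', 'no', 'yes', 'no']))  # -> 'no'
```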

6. Build the decision tree

def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]          # stop splitting when all classes are equal
    if len(dataSet[0]) == 1:         # stop splitting when no features are left
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]  # all values of the best feature
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]        # copy labels so recursion doesn't mess up the caller's list
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree

myDat, labels = createDataSet()      # rebuild a clean copy (myDat was modified above)
print('createTree is :', createTree(myDat, labels))

myDat, labels = createDataSet()
mytree = createTree(myDat, labels)

Output:
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
7. Classify with the decision tree:


def classify(inputTree, featLabels, testVec):
    firstStr = list(inputTree.keys())[0]    # the feature tested at the root of this (sub)tree
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)  # index of the first element matching firstStr
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            # a dict value means an internal node; anything else is a leaf
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel

myDat, labels = createDataSet()
mytree = createTree(myDat, labels)
print(mytree)
myDat, labels = createDataSet()
classlabel_1 = classify(mytree, labels, [1, 0])
print('[1,0] is :', classlabel_1)
classlabel_2 = classify(mytree, labels, [1, 1])
print('[1,1] is:', classlabel_2)

Output:
[1,0] is classified as 'no' and [1,1] as 'yes'.
8. Storing and loading the tree

def storeTree(inputTree, filename):
    import pickle
    fw = open(filename, 'wb')    # pickle requires binary mode in Python 3
    pickle.dump(inputTree, fw)
    fw.close()

# Load the tree back
def grabTree(filename):
    import pickle
    fr = open(filename, 'rb')
    return pickle.load(fr)

# Test and print
storeTree(mytree, 'classifierstorage.txt')
print(grabTree('classifierstorage.txt'))
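Since pickle files must be opened in binary mode in Python 3, here is a quick round-trip check (the temp-file path is arbitrary):

```python
import os
import pickle
import tempfile

tree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
path = os.path.join(tempfile.gettempdir(), 'tree.pkl')   # illustrative temp path

with open(path, 'wb') as fw:    # binary write for pickle
    pickle.dump(tree, fw)
with open(path, 'rb') as fr:    # binary read for pickle
    restored = pickle.load(fr)

print(restored == tree)  # -> True
```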

3. Predicting contact lens type with a decision tree

# Predict the contact lens type with a decision tree
fr = open('lenses.txt')
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
print(lenses)
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree = createTree(lenses, lensesLabels)
print(lensesTree)
treePlotter.createPlot(lensesTree)

Output: the lenses data, the learned tree, and the plotted lenses decision tree (figure omitted).

4. Plotting the tree

import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]   # Python 3: dict.keys() is not indexable
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':   # a dict means an internal node; otherwise a leaf
            numLeafs += getNumLeafs(secondDict[key])
        else:   numLeafs +=1
    return numLeafs

def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]   # Python 3: dict.keys() is not indexable
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':   # a dict means an internal node; otherwise a leaf
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:   thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',
             xytext=centerPt, textcoords='axes fraction',
             va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )

def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):   # the first key tells you what feature was split on
    numLeafs = getNumLeafs(myTree)  #this determines the x width of this tree
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]   # the text label for this node
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':   # a dict means an internal node; otherwise a leaf
            plotTree(secondDict[key],cntrPt,str(key))        #recursion
        else:   #it's a leaf node print the leaf node
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
# if you do get a dictionary, you know it's a tree, and the first element will be another dict

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)    #no ticks
    #createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses 
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5,1.0), '')
    plt.show()


def retrieveTree(i):
    listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                  {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                  ]
    return listOfTrees[i]
mytree = retrieveTree(0)
createPlot(mytree)

Output: the plotted example tree (figures omitted).
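As a quick check of getNumLeafs and getTreeDepth, here is a compact recursive sketch (not the book's code) that reproduces their counts on the first hard-coded example tree:

```python
tree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

def num_leafs(t):
    # Sum leaf counts over the children of the single root feature.
    return sum(num_leafs(v) if isinstance(v, dict) else 1
               for v in next(iter(t.values())).values())

def depth(t):
    # A leaf child contributes depth 1; a subtree child adds one level.
    return max(1 + depth(v) if isinstance(v, dict) else 1
               for v in next(iter(t.values())).values())

print(num_leafs(tree), depth(tree))  # -> 3 2
```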

5. Summary

Pros: low computational cost, results that are easy for humans to interpret, insensitivity to missing values, and the ability to handle irrelevant features.
Cons: prone to overfitting.
Suitable data types: numeric and nominal.
The ID3 algorithm covers only tree generation, so the trees it produces overfit easily. C4.5, CART, and decision tree pruning will be discussed in detail later.