
Machine Learning Algorithms: Decision Trees

Algorithm Overview

  • Advantages: low computational complexity; output is easy to interpret; insensitive to missing intermediate values; can handle irrelevant features.
  • Disadvantages: prone to overfitting.
  • Applicable data types: numeric and nominal.

Algorithm Flow

  • Data
    • Training samples (multiple rows of multi-dimensional data + labels)
    • Prediction data (a single row of multi-dimensional data)

  • Building the decision tree
  1. Iterate over every feature in the dataset and compute the information gain from splitting on it
  2. Choose the feature with the largest information gain as the current tree node
  3. Partition the dataset by that feature's values, then apply steps 1-3 recursively to each partition
    • Recursion exits
      • Only the label column remains, so no further split is possible
      • All labels are identical, so no feature can change the outcome
  • Decision tree prediction

    • Follow the tree structure with the prediction data's feature values until a leaf node is reached
  • Entropy formula: H = -Σ pi * log2(pi), where pi is the probability of the i-th value among all values (a worked example follows below)
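
As a worked example of the formula: the dataset used in the code below has 15 samples, 9 labelled 'yes' and 6 labelled 'no', so the entropy of its label column is

    H = -(9/15) * log2(9/15) - (6/15) * log2(6/15) ≈ 0.971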

Code Example

import collections
import math

import numpy as np

def calcShannonEnt(dataCol):
    """
    Shannon entropy of a column of values.
    H = -Σ pi * log2(pi)
    where pi is the probability of the i-th value among all values.
    :param dataCol: 1-D array of values (e.g. the label column)
    :return: the entropy as a float
    """
    labelNum = dataCol.shape[0]
    labelCounts = {}
    for label in dataCol:
        if label not in labelCounts:
            labelCounts[label] = 0
        labelCounts[label] += 1
    entropy = 0.0
    for label, count in labelCounts.items():
        # probability of this value among all values
        prob = count / labelNum
        entropy -= prob * math.log(prob, 2)
    return entropy
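
# Quick sanity check (a minimal usage sketch; the counts follow from the
# dataset in __main__ below, whose label column holds 9 'yes' and 6 'no'):
#   calcShannonEnt(np.array(['yes'] * 9 + ['no'] * 6))  # ≈ 0.9710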

class DecisionTree:

    def __init__(self, dataSet, labels):
        self.tree = self.createTree(dataSet, labels)
        self.numLeafs = self.getNumLeafs(self.tree)
        self.depth = self.getTreeDepth(self.tree)

    def splitDataSet(self, dataSet, axis, value):
        """
        Extract the rows whose feature `axis` equals `value`,
        with that feature column removed.
        :param dataSet: 2-D array of samples (features + label column)
        :param axis: index of the feature to split on
        :param value: feature value to keep
        :return: the reduced sub-dataset
        """
        reDataSet = []
        for featVec in dataSet:
            if featVec[axis] == value:
                reDataSet.append(np.hstack((featVec[:axis], featVec[axis+1:])).tolist())
        return np.array(reDataSet)
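
    # Example: splitting the dataset in __main__ on feature 2 ('owns house')
    # with value '1' keeps the six matching rows and drops that column,
    # yielding a 6 x 4 array of the remaining features plus the label.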

    def chooseBestFeatureToSplit(self, dataSet):
        """
        Pick the feature whose split yields the largest information gain.
        :param dataSet: 2-D array of samples (features + label column)
        :return: index of the best feature
        """
        numFeatures = len(dataSet[0]) - 1
        baseEntropy = calcShannonEnt(dataSet[:, -1])
        bestInfoGain, bestFeature = 0.0, -1
        for i in range(numFeatures):
            featList = dataSet[:, i]
            uniqueVals = set(featList)
            newEntropy = 0.0
            for value in uniqueVals:
                subDataSet = self.splitDataSet(dataSet, i, value)
                prob = len(subDataSet) / len(dataSet)
                # conditional entropy: weight each branch's entropy by the
                # fraction of samples that fall into that branch
                newEntropy += prob * calcShannonEnt(subDataSet[:, -1])
            infoGain = baseEntropy - newEntropy
            if infoGain > bestInfoGain:
                bestInfoGain = infoGain
                bestFeature = i
        return bestFeature
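
    # On the dataset in __main__, the first call finds approximately:
    # gain(age) ≈ 0.083, gain(has job) ≈ 0.324,
    # gain(owns house) ≈ 0.420, gain(credit) ≈ 0.363,
    # so 'owns house' is chosen as the root split.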

    def majorityCnt(self, classList):
        """
        Return the class that occurs most often.
        :param classList: 1-D array of class labels
        :return: the majority class
        """
        return collections.Counter(classList).most_common(1)[0][0]
        # Equivalent manual implementation:
        # classCount = {}
        # for vote in classList:
        #     if vote not in classCount:
        #         classCount[vote] = 0
        #     classCount[vote] += 1
        # sortedClassCount = sorted(classCount.items(), key=lambda x: x[1], reverse=True)
        # return sortedClassCount[0][0]

    def createTree(self, dataSet, labels):
        """
        Build the decision tree recursively.
        :param dataSet: 2-D array of samples (features + label column)
        :param labels: feature names matching the dataset's columns
        :return: the tree as nested dicts, or a class label at a leaf
        """
        classList = dataSet[:, -1]
        # If all samples share the same class label, return that label
        if np.unique(classList).size == 1:
            return classList[0]
        # If only the label column is left (no features to split on),
        # return the majority class
        if len(dataSet[0]) == 1:
            return self.majorityCnt(classList)
        # Pick the feature with the largest information gain
        bestFeat = self.chooseBestFeatureToSplit(dataSet)
        bestFeatLabel = labels[bestFeat]
        myTree = {bestFeatLabel: {}}
        del labels[bestFeat]
        # Collect this feature's values; each unique value becomes a branch
        featValues = [example[bestFeat] for example in dataSet]
        uniqueVals = set(featValues)
        for value in uniqueVals:
            subLabels = labels[:]
            # Build each subtree recursively
            myTree[bestFeatLabel][value] = self.createTree(
                self.splitDataSet(dataSet, bestFeat, value), subLabels)
        return myTree
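
    # The tree is stored as nested dicts of the form
    # {feature_name: {feature_value: subtree_or_class_label, ...}}.
    # On the dataset in __main__ this yields, structurally:
    # {'owns house': {'0': {'has job': {'0': 'no', '1': 'yes'}}, '1': 'yes'}}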

    def getNumLeafs(self, myTree):
        numLeafs = 0                                   # leaf counter
        firstStr = next(iter(myTree))                  # feature tested at this node
        secondDict = myTree[firstStr]                  # branches under it
        for key in secondDict.keys():
            if isinstance(secondDict[key], dict):      # a dict is an internal node
                numLeafs += self.getNumLeafs(secondDict[key])
            else:
                numLeafs += 1
        return numLeafs

    def getTreeDepth(self, myTree):
        maxDepth = 0                                   # depth of this subtree
        firstStr = next(iter(myTree))
        secondDict = myTree[firstStr]                  # branches under the root
        for key in secondDict.keys():
            if isinstance(secondDict[key], dict):      # a dict is an internal node; anything else is a leaf
                thisDepth = 1 + self.getTreeDepth(secondDict[key])
            else:
                thisDepth = 1
            if thisDepth > maxDepth:
                maxDepth = thisDepth                   # keep the deepest branch
        return maxDepth

    def classify(self, inputTree, labels, testVec):
        """
        Predict the class of a single sample.
        :param inputTree: the decision tree
        :param labels: feature names matching the test vector's columns
        :param testVec: the sample to classify
        :return: the predicted class, or None for an unseen feature value
        """
        classLabel = None                     # stays None if no branch matches
        firstStr = next(iter(inputTree))      # feature tested at this node
        secondDict = inputTree[firstStr]      # branches under it
        featIndex = labels.index(firstStr)
        for key in secondDict.keys():
            if str(testVec[featIndex]) == key:
                if isinstance(secondDict[key], dict):
                    classLabel = self.classify(secondDict[key], labels, testVec)
                else:
                    classLabel = secondDict[key]
        return classLabel

if __name__ == "__main__":
    dataSet = [[0, 0, 0, 0, 'no'],  # loan-application dataset
               [0, 0, 0, 1, 'no'],
               [0, 1, 0, 1, 'yes'],
               [0, 1, 1, 0, 'yes'],
               [0, 0, 0, 0, 'no'],
               [1, 0, 0, 0, 'no'],
               [1, 0, 0, 1, 'no'],
               [1, 1, 1, 1, 'yes'],
               [1, 0, 1, 2, 'yes'],
               [1, 0, 1, 2, 'yes'],
               [2, 0, 1, 2, 'yes'],
               [2, 0, 1, 1, 'yes'],
               [2, 1, 0, 1, 'yes'],
               [2, 1, 0, 2, 'yes'],
               [2, 0, 0, 0, 'no']]
    labels = ['age', 'has job', 'owns house', 'credit']  # feature names
    DTree = DecisionTree(np.array(dataSet), labels[:])
    print("tree:\t",DTree.tree)
    print("leaf nums:\t",DTree.numLeafs)
    print("deapth:\t",DTree.deapth)
    print("classify:\t", DTree.classify(DTree.tree, labels, [3,1,0,"yes"]))