Machine Learning (2): Decision Trees

The code below builds an ID3-style decision tree: compute the Shannon entropy of a data set, split the data on a feature value, pick the feature with the highest information gain, and recurse.

from math import log
import operator
#compute the Shannon entropy: H = -sum over classes of p*log2(p)
def calcShannonEnt(dataSet):
    numEntries=len(dataSet)
    labelCounts={}
    for featVec in dataSet:#count how many examples fall in each class
        currentLabel=featVec[-1]#the class label is the last column
        if currentLabel not in labelCounts:
            labelCounts[currentLabel]=0
        labelCounts[currentLabel]+=1
    shannonEnt=0.0
    for key in labelCounts:
        prob=float(labelCounts[key])/numEntries
        shannonEnt-=prob*log(prob,2)
    return shannonEnt
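#Worked example (hand-computed, not from the original post): for class counts
#{yes:2, no:3}, H=-(2/5)*log2(2/5)-(3/5)*log2(3/5), which is about 0.9710 bits.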
#build a small sample data set for testing
def createDataSet():
    dataSet=[[1,1,'yes'],
             [1,1,'yes'],
             [1,0,'no'],
             [0,1,'no'],
             [0,1,'no']]
    labels=['no surfacing','flippers']
    return dataSet,labels
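#Quick check (my own example, not from the original post):
#  myDat,labels=createDataSet()
#  calcShannonEnt(myDat) returns about 0.9710 (two 'yes' vs three 'no')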
#split the data set: keep the rows whose feature at index axis equals value
def splitDataSet(dataSet,axis,value):
    retDataSet=[]
    for featVec in dataSet:
        if featVec[axis]==value:
            reducedFeatVec=featVec[:axis]#elements before the chosen feature
            reducedFeatVec.extend(featVec[axis+1:])#extend flattens; append would nest the list
            retDataSet.append(reducedFeatVec)
    return retDataSet
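#Quick check (my own example, not from the original post):
#  splitDataSet(myDat,0,1) returns [[1,'yes'],[1,'yes'],[0,'no']]
#  note that the matched column (axis 0) is removed from each row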
#choose the feature with the highest information gain
def chooseBestFeatureToSplit(dataSet):
    numFeatures=len(dataSet[0])-1#the last column is the class label, not a feature
    baseEntropy=calcShannonEnt(dataSet)
    bestInfoGain=0.0
    bestFeature=-1
    for i in range(numFeatures):
        featList=[example[i] for example in dataSet]
        uniqueVals=set(featList)
        newEntropy=0.0
        for value in uniqueVals:#expected entropy after splitting on feature i
            subDataSet = splitDataSet(dataSet,i,value)
            prob = len(subDataSet)/float(len(dataSet))
            newEntropy += prob*calcShannonEnt(subDataSet)
        infoGain=baseEntropy-newEntropy
        if infoGain>bestInfoGain:#keep the split with the best information gain
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
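#Quick check (my own example, not from the original post): on the sample data
#chooseBestFeatureToSplit(myDat) returns 0, since splitting on 'no surfacing'
#gains about 0.42 bits versus about 0.17 bits for 'flippers'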
#if no features remain but a leaf's classes are still mixed, vote for the majority class
def majorityCnt(classList):
    classCount={}
    for vote in classList:
        if vote not in classCount:
            classCount[vote]=0
        classCount[vote]+=1
    sortedClassCount=sorted(classCount.items(),#iteritems() is Python 2 only
                            key=operator.itemgetter(1),reverse=True)
    return sortedClassCount[0][0]
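#An equivalent Python 3 alternative (my suggestion, not the original code):
#  from collections import Counter
#  return Counter(classList).most_common(1)[0][0]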
#recursively build the decision tree
def createTree(dataSet,labels):
    classList=[example[-1] for example in dataSet]
    if classList.count(classList[0])==len(classList):#all examples share one class
        return classList[0]
    if len(dataSet[0])==1:#no more features to split on
        return majorityCnt(classList)
    bestFeat=chooseBestFeatureToSplit(dataSet)#bestFeat is the index of the best feature
    bestFeatLabel=labels[bestFeat]
    myTree={bestFeatLabel:{}}
    del labels[bestFeat]
    featValues=[example[bestFeat] for example in dataSet]
    uniqueFeatVals=set(featValues)
    for value in uniqueFeatVals:
        subLabels=labels[:]#copy, so the recursion cannot mutate the caller's label list
        myTree[bestFeatLabel][value]=createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
    return myTree
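
Putting it together, here is a short driver (my own sketch, not part of the original post; the expected outputs come from hand-running the functions above on the sample data):

if __name__=='__main__':
    myDat,labels=createDataSet()
    print(calcShannonEnt(myDat))#about 0.9710
    print(splitDataSet(myDat,0,1))#[[1,'yes'],[1,'yes'],[0,'no']]
    print(chooseBestFeatureToSplit(myDat))#0, i.e. 'no surfacing'
    myTree=createTree(myDat,labels[:])#pass a copy: createTree deletes from labels
    print(myTree)
    #expected (key order may vary):
    #{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}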
