Machine Learning (2): Decision Trees
阿新 • Published: 2019-02-03
from math import log
import operator

# Compute the Shannon entropy of a data set
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:  # count occurrences of every class label
        currentLabel = featVec[-1]  # the class label is the last column
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt

# Toy data set for testing
def createDataSet():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels

# Split the data set: keep rows whose feature `axis` equals `value`, dropping that column
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]  # columns before `axis`
            reducedFeatVec.extend(featVec[axis+1:])  # note the difference between extend and append
            retDataSet.append(reducedFeatVec)
    return retDataSet

# Choose the best feature to split on
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1  # the last column is the class label, not a feature
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0
        for value in uniqueVals:  # entropy after splitting on feature i
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:  # keep the feature with the largest information gain
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature

# If a branch runs out of features but the classes are still mixed, take a majority vote
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

# Build the tree recursively
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):  # all examples share one class
        return classList[0]
    if len(dataSet[0]) == 1:  # no more features to split on
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)  # index of the best feature
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueFeatVals = set(featValues)
    for value in uniqueFeatVals:
        subLabels = labels[:]  # copy so the recursion does not mutate the caller's label list
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
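As a quick sanity check, here is a small driver sketch (my own addition, not part of the original post) that builds the toy data set, prints its entropy, and grows the tree. For this data set the entropy is -(2/5)log2(2/5) - (3/5)log2(3/5) ≈ 0.971, and "no surfacing" gives the largest information gain, so it becomes the root.

if __name__ == '__main__':
    myDat, labels = createDataSet()
    print(calcShannonEnt(myDat))           # expected: about 0.971
    myTree = createTree(myDat, labels[:])  # pass a copy: createTree deletes used labels
    print(myTree)
    # expected: {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

Passing labels[:] instead of labels keeps the original label list intact for later use, since createTree removes the chosen feature's label at each split.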