Implementing a Decision Tree in Python 3 (Machine Learning in Action)
from math import log
import operator

def calcShannonEnt(dataSet):  # compute the Shannon entropy of the given data set
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:  # count how often each class label (last column) appears
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt

mydata = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
print(calcShannonEnt(mydata))
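With two 'yes' and three 'no' labels in mydata, the print above should come out to -(2/5)*log2(2/5) - (3/5)*log2(3/5) ≈ 0.9710.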
def splitDataSet(dataSet, axis, value):  # dataSet is the data set, axis is the feature index, value is the feature value to match
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]          # keep everything before the chosen feature
            reducedFeatVec.extend(featVec[axis+1:])  # and everything after it, dropping the feature itself
            retDataSet.append(reducedFeatVec)
    return retDataSet
# This function extracts the rows whose feature at index axis equals value.
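As a quick sanity check (reusing the mydata list defined above), splitting on feature 0 with value 1 should keep the three samples whose first feature is 1 and drop that column:

print(splitDataSet(mydata, 0, 1))  # expected: [[1, 'yes'], [1, 'yes'], [0, 'no']]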
def chooseBestFeatureToSplit(dataSet):  # choose the best feature to split on (highest information gain)
    numFeatures = len(dataSet[0]) - 1   # the last column is the class label, not a feature
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature

mydata = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
print(chooseBestFeatureToSplit(mydata))
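For mydata this should print 0: splitting on the first feature gives an information gain of roughly 0.971 - 0.551 ≈ 0.42, versus about 0.17 for the second feature.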
def majorityCnt(classList):  # if no features remain, return the majority class as the leaf node
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
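A small check of the majority vote (the list here is just a made-up example, not from the book's data set):

print(majorityCnt(['yes', 'no', 'no']))  # expected: 'no'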
def createTree(dataSet, labels):  # build the decision tree recursively
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):  # all samples share one class: return a leaf
        return classList[0]
    if len(dataSet[0]) == 1:                              # no features left: return the majority class
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]  # copy so the recursion does not clobber this level's label list
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree

labels = ['no surface', 'flippers']
print(createTree(mydata, labels))
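On mydata this should print {'no surface': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}. To actually use the tree on a new sample, a minimal classification sketch along these lines should work (the classify/featLabels/testVec names are assumptions of mine, not part of the listing above; since createTree deletes entries from labels, a fresh label list is passed here):

def classify(inputTree, featLabels, testVec):
    # look up which feature this node tests, then follow the branch matching the sample's value
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    classLabel = None
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if isinstance(secondDict[key], dict):   # internal node: recurse
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:                                   # leaf node: return its class
                classLabel = secondDict[key]
    return classLabel

myTree = createTree(mydata, ['no surface', 'flippers'])
print(classify(myTree, ['no surface', 'flippers'], [1, 0]))  # expected: 'no'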