Machine Learning in Action -- Decision Trees
阿新 · Published: 2018-12-15
Code:
import numpy as np
import operator

# Compute the Shannon entropy, a measure of how disordered the data set is
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * np.log2(prob)
    return shannonEnt

# Split the data set on the given feature (axis) and the given value of that feature
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

# Pick the feature whose split yields the largest information gain
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            # the minus sign is already applied inside the entropy function
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature

# If all features are used up and a leaf still holds more than one class,
# return the class that occurs most often
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

# Tree-building function; labels holds the feature names
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    # every example has the same class: recursion base case, return that class
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # all features have been consumed: base case, return the majority class of the leaf
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    # chooseBestFeatureToSplit returns the column index of the best feature
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del labels[bestFeat]
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree

# Toy data set: two binary features plus a class label per row
def createDataSet():
    dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
    labels = ['survives underwater', 'has flippers']
    return dataSet, labels

if __name__ == '__main__':
    # dataset, labels = createDataSet()
    # shannonEnt = calcShannonEnt(dataset)
    # print(shannonEnt)
    # p = 1/5
    # a = p*np.log2(p)
    # print(-a*5)
    # vocabset = set([])
    # vocabset |= set(['a','b','c'])
    # print(vocabset)
    dataSet, labels = createDataSet()
    print(createTree(dataSet, labels))
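As a quick sanity check on calcShannonEnt: the sample data set has 2 'yes' and 3 'no' labels, so its entropy should be -(2/5)*log2(2/5) - (3/5)*log2(3/5) ≈ 0.971 bits. The snippet below is not part of the original post; it is a small verification meant to run in the same module as the code above.

    # Sanity check (assumes the functions above are already defined in this module)
    dataSet, labels = createDataSet()
    byHand = -(2/5) * np.log2(2/5) - (3/5) * np.log2(3/5)
    print(round(byHand, 3))                    # 0.971, computed directly from the formula
    print(round(calcShannonEnt(dataSet), 3))   # 0.971, same value from the function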
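The post stops at building the tree. To actually use the nested dict returned by createTree for prediction, a helper along the lines of the sketch below works; the classify function and the [1, 0] test vector are illustrations added here, not code from the original article. Note that createTree deletes entries from labels as it recurses, so a copy of the feature names has to be kept for classification.

    # Hypothetical helper (not in the original post): walk the nested-dict tree
    # produced by createTree to predict the class of one feature vector.
    def classify(inputTree, featLabels, testVec):
        firstStr = next(iter(inputTree))         # feature name stored at the current node
        secondDict = inputTree[firstStr]         # branches keyed by that feature's values
        featIndex = featLabels.index(firstStr)   # which column of testVec holds this feature
        valueOfFeat = secondDict[testVec[featIndex]]
        if isinstance(valueOfFeat, dict):        # internal node: keep descending
            return classify(valueOfFeat, featLabels, testVec)
        return valueOfFeat                       # leaf node: the predicted class

    # Example usage
    dataSet, labels = createDataSet()
    featLabels = labels[:]                       # keep a copy, since createTree mutates labels
    myTree = createTree(dataSet, labels)
    print(classify(myTree, featLabels, [1, 0]))  # expected: 'no'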
Output:
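With the five-row sample data set, the information gain of the first feature (about 0.42 bits) is larger than that of the second (about 0.17 bits), so the tree splits on 'survives underwater' first. Running the script should therefore print a nested dict of the form:

    {'survives underwater': {0: 'no', 1: {'has flippers': {0: 'no', 1: 'yes'}}}}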