1. 程式人生 > 其它 > 從零開始寫程式碼 ID3 決策樹 Python

從零開始寫程式碼 ID3決策樹Python

視訊版地址(B 站):《從零開始寫程式碼 Python ID3 決策樹演算法分析與實現》(嗶哩嗶哩 bilibili)

程式碼如下:

# author:會武術之白貓
# date:2021-11-6
import math
from collections import Counter

def createDataSet():
    """Return the toy cold-vs-flu training set and its feature labels.

    Each row holds seven integer-coded symptom values followed by the class
    label ('感冒' or '流感') as the last element; ``labels[i]`` names
    column ``i``.

    :return: ``(dataSet, labels)``
    """
    dataSet = [
        [1, 1, 2, 0, 1, 1, 0, '感冒'],
        [2, 0, 3, 2, 0, 2, 2, '流感'],
        [3, 0, 0, 1, 1, 1, 1, '流感'],
        [0, 0, 1, 1, 1, 0, 1, '感冒'],
        [3, 1, 2, 2, 0, 2, 2, '流感'],
        [0, 1, 2, 0, 1, 0, 0, '感冒'],
        [2, 0, 2, 2, 0, 2, 2, '流感'],
        [0, 1, 3, 0, 0, 1, 1, '感冒'],
    ]
    labels = ['發冷', '喉嚨痛', '咳嗽', '頭痛', '鼻塞', '疲勞', '發燒']
    return dataSet, labels


def calcShannonEnt(dataSet):
    """Return the base-2 Shannon entropy of the class labels in *dataSet*.

    The class label is taken from the last element of each row.
    An empty *dataSet* has entropy 0.0.
    """
    numEntries = len(dataSet)
    labelCounts = Counter(featVec[-1] for featVec in dataSet)
    shannonEnt = 0.0
    for count in labelCounts.values():
        prob = count / numEntries
        shannonEnt += prob * math.log2(1 / prob)
    return shannonEnt


def splitDataSet(dataSet, axis, value):
    """Return the rows whose column *axis* equals *value*, minus that column.

    :param dataSet: rows to filter (not modified; new rows are built by slicing)
    :param axis: non-negative index of the feature column to split on
    :param value: feature value to keep
    :return: list of reduced rows (the split feature column removed)
    """
    return [
        featVec[:axis] + featVec[axis + 1:]
        for featVec in dataSet
        if featVec[axis] == value
    ]


def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature with the largest information gain.

    Information gain = entropy(dataSet) - sum_v p(v) * entropy(subset_v).
    Ties keep the lowest index (strict ``>`` comparison).

    :return: best feature index, or -1 if no feature gives a strictly
        positive gain (e.g. every feature is constant).
    """
    numFeatures = len(dataSet[0]) - 1  # last column is the class label
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        newEntropy = 0.0
        for value in {row[i] for row in dataSet}:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / len(dataSet)  # p(value)
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


def majorityCnt(classList):
    """Return the most frequent label in *classList*.

    Ties go to the label encountered first (``Counter.most_common`` keeps
    first-encountered order for equal counts).
    """
    return Counter(classList).most_common(1)[0][0]


def createTree(dataSet, labels):
    """Build an ID3 decision tree from *dataSet*.

    :param dataSet: rows of feature values with the class label last
    :param labels: feature names, ``labels[i]`` naming column ``i``.
        The caller's sequence is NOT modified.
    :return: a class label (leaf), or a nested dict
        ``{feature_label: {feature_value: subtree_or_label, ...}}``
    """
    classList = [example[-1] for example in dataSet]
    # All rows share one class: stop splitting.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # No features left (only the class column remains): majority vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    # No feature reduces entropy; indexing with -1 would split on the class
    # column itself, so emit a majority-vote leaf instead.
    if bestFeat < 0:
        return majorityCnt(classList)
    bestFeatLabel = labels[bestFeat]
    # Build a new label list instead of deleting from the caller's list.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    myTree = {bestFeatLabel: {}}
    for value in {example[bestFeat] for example in dataSet}:
        # Safe to share subLabels between branches: recursion never mutates it.
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels
        )
    return myTree


def classify(inputTree, featLabels, testVec):
    """Classify *testVec* by walking a tree built by ``createTree``.

    :param inputTree: decision tree (nested dict) or a bare class label
    :param featLabels: full list of feature names, in column order of *testVec*
    :param testVec: feature values to classify
    :return: predicted class label, or None if *testVec* has a feature value
        that never appeared at the corresponding node during training
    """
    # createTree returns a bare label when the training set was pure.
    if not isinstance(inputTree, dict):
        return inputTree
    firstStr = next(iter(inputTree))  # feature tested at this node
    branches = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    subtree = branches.get(testVec[featIndex])
    if isinstance(subtree, dict):
        return classify(subtree, featLabels, testVec)
    return subtree


if __name__ == '__main__':
    dataSet, labels = createDataSet()
    myTree = createTree(dataSet, labels)
    print(myTree)
    # labels is intact because createTree no longer mutates it.
    res = classify(myTree, labels, [1, 1, 2, 0, 1, 1, 0])
    print(res)