西瓜書 課後習題4.4 基尼指數 未剪枝 預剪枝 後剪枝
阿新 • • 發佈:2018-11-27
"""Decision tree built with the Gini index (watermelon dataset 2.0).

Solution to exercise 4.4 of the "watermelon book": the tree can be grown
without pruning, with pre-pruning, or with post-pruning (see the ``pruning``
argument of :func:`createTree`).

Data rows are lists of strings whose last element is the class label; a tree
is a nested ``dict`` of the form ``{feature_name: {feature_value: subtree}}``
where a subtree is either another such dict or a class label (leaf).
"""
import csv
import operator
import sys
from collections import Counter

import numpy as np  # not used below; kept because the original script imports it

# 1-based sample numbers that make up the training split; the remaining
# samples among the first NUM_SAMPLES rows form the test (validation) split.
TRAIN_INDEX = (1, 2, 3, 6, 7, 10, 14, 15, 16, 17)
NUM_SAMPLES = 17


def readDataset(filename):
    '''
    Read the data file and split it into training and test sets.

    The first CSV column (sample number) is dropped from every row. The file
    is opened with the platform default encoding, as in the original script.

    :param filename: name of the data file, CSV format with a header row
    :return: tuple ``(dataset, labels, trainDataset, testDataset)`` where
        ``dataset`` holds every data row, ``labels`` holds header columns
        1..6 (the feature names) and the last two are the split row lists
    :raises IndexError: if the file has fewer than ``NUM_SAMPLES`` data rows
    '''
    # newline='' is what the csv module documents for files handed to reader().
    with open(filename, newline='') as f:
        reader = csv.reader(f)
        header_row = next(reader)
        labels = header_row[1:7]
        dataset = [line[1:] for line in reader]

    trainDataset = []
    testDataset = []
    for number in range(1, NUM_SAMPLES + 1):
        row = dataset[number - 1]
        if number in TRAIN_INDEX:
            trainDataset.append(row)
        else:
            testDataset.append(row)
    # To keep the result identical to the book, the fourth sample is also
    # added to the training set (it stays in the test set as well).
    trainDataset.append(dataset[3])
    return dataset, labels, trainDataset, testDataset


def Gini(dataset):
    '''
    Compute the Gini value of a data set.

    :param dataset: list of rows; the last element of each row is its label
    :return: ``1 - sum(p_k ** 2)`` over the label proportions ``p_k``
        (``1`` for an empty data set)
    '''
    numdata = len(dataset)
    labelCounts = Counter(featVec[-1] for featVec in dataset)
    gini = 1.0
    # Subtract term by term (in first-seen label order) so the floating-point
    # result does not depend on hash ordering.
    for count in labelCounts.values():
        prop = count / numdata
        gini -= prop ** 2
    return gini


def splitDataset(dataset, axis, value):
    '''
    Select the rows that take a given value on a feature and drop that column.

    :param dataset: list of rows (each row a list)
    :param axis: index of the feature used for the split
    :param value: feature value to keep
    :return: new list of new rows, without column ``axis``, for the rows
        whose ``axis`` entry equals ``value``
    '''
    return [
        featVec[:axis] + featVec[axis + 1:]
        for featVec in dataset
        if featVec[axis] == value
    ]


def bestFeatureSplit(dataset):
    '''
    Choose the feature with the smallest Gini index.

    :param dataset: non-empty list of rows to split (last element = label)
    :return: index of the best feature; ties go to the lowest index, and
        ``-1`` is returned when the rows contain no feature columns
    '''
    numFeature = len(dataset[0]) - 1
    total = len(dataset)
    bestGiniIndex = float('inf')
    bestFeature = -1
    for i in range(numFeature):
        # dict.fromkeys keeps first-seen order, so the summation order (and
        # therefore the float result used for tie-breaking) is reproducible.
        uniqueValue = dict.fromkeys(example[i] for example in dataset)
        giniIndex = 0.0
        for value in uniqueValue:
            subDataset = splitDataset(dataset, i, value)
            prop = len(subDataset) / total
            giniIndex += prop * Gini(subDataset)
        if giniIndex < bestGiniIndex:
            bestGiniIndex = giniIndex
            bestFeature = i
    return bestFeature


def majorClass(classList):
    '''
    Majority vote over the class labels that reach a leaf.

    :param classList: non-empty list of class labels
    :return: the most frequent label; ties go to the label seen first
    :raises IndexError: if ``classList`` is empty
    '''
    classCount = Counter(classList)
    # sorted() is stable, so equal counts keep their first-seen order.
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


def decideTreePredict(decideTree, testData, labelsFull):
    '''
    Predict the class of one sample with a decision tree.

    :param decideTree: decision tree (nested dict, see module docstring)
    :param testData: one data row
    :param labelsFull: feature names giving the column layout of ``testData``
    :return: predicted label, or ``None`` when the sample takes a feature
        value for which the tree has no branch
    '''
    firstFeat = next(iter(decideTree))
    secDict = decideTree[firstFeat]
    featIndex = labelsFull.index(firstFeat)
    branch = secDict.get(testData[featIndex])
    if isinstance(branch, dict):
        return decideTreePredict(branch, testData, labelsFull)
    return branch


def prevReduceBranch(bestFeatLabel, trainDataset, testDataset, labelsFull):
    '''
    Count test errors if the node is split on ``bestFeatLabel`` and every
    child is a leaf labelled with the majority class of its training rows.

    NOTE: test rows whose feature value does not occur in ``trainDataset``
    are not counted as errors (behavior of the original implementation).

    :param bestFeatLabel: name of the feature considered for the split
    :param trainDataset: training rows reaching the node
    :param testDataset: test rows reaching the node
    :param labelsFull: feature names giving the column layout of both
        ``trainDataset`` and ``testDataset``
    :return: number of misclassified test rows
    '''
    bestFeatIndex = labelsFull.index(bestFeatLabel)
    classesByValue = {}
    for example in trainDataset:
        classesByValue.setdefault(example[bestFeatIndex], []).append(example[-1])
    majorByValue = {value: majorClass(classes)
                    for value, classes in classesByValue.items()}
    return sum(
        1
        for data in testDataset
        if data[bestFeatIndex] in majorByValue
        and data[-1] != majorByValue[data[bestFeatIndex]]
    )


def majorTest(major, testData):
    '''
    Count test errors if the node is turned into a leaf labelled ``major``.

    :param major: class label assigned to the leaf
    :param testData: test rows reaching the node
    :return: number of test rows whose label differs from ``major``
    '''
    return sum(1 for sample in testData if sample[-1] != major)


def postReduceBranch(subTree, testData, labelsFull):
    '''
    Count test errors made by keeping ``subTree`` (used by post-pruning).

    :param subTree: decision (sub)tree to evaluate
    :param testData: test rows reaching the root of ``subTree``
    :param labelsFull: feature names giving the column layout of ``testData``
    :return: number of misclassified test rows (a ``None`` prediction counts
        as an error)
    '''
    return sum(
        1
        for sample in testData
        if decideTreePredict(subTree, sample, labelsFull) != sample[-1]
    )


def createTree(trainDataset, labels, datasetFull, labelsFull, testDataset,
               pruning=None):
    '''
    Recursively build a decision tree.

    :param trainDataset: non-empty list of training rows reaching this node
    :param labels: feature names matching the columns of ``trainDataset``
        (and of ``testDataset``); the list is not modified
    :param datasetFull: complete, unsplit training data, used to find feature
        values that do not occur in ``trainDataset``
    :param labelsFull: feature names matching the columns of ``datasetFull``
    :param testDataset: test rows reaching this node (used only for pruning)
    :param pruning: ``None`` (default) for an unpruned tree, ``'pre'`` for
        pre-pruning or ``'post'`` for post-pruning
    :return: the decision tree as a nested dict, or a class label if the
        node is a leaf
    :raises ValueError: if ``pruning`` is not one of the values above
    '''
    if pruning not in (None, 'pre', 'post'):
        raise ValueError("pruning must be None, 'pre' or 'post', got %r"
                         % (pruning,))

    classList = [example[-1] for example in trainDataset]
    # All samples share one class: leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Only the label column is left (no feature to split on): majority leaf.
    # This must look at the rows of *this* node, not at a global data set.
    if len(trainDataset[0]) == 1:
        return majorClass(classList)

    bestFeat = bestFeatureSplit(trainDataset)
    bestFeatLabel = labels[bestFeat]
    major = majorClass(classList)

    # Pre-pruning: split only if it strictly lowers the test error compared
    # with making this node a leaf. The rows at this node are laid out
    # according to ``labels``, so that list (not ``labelsFull``) is passed.
    if pruning == 'pre':
        splitErrors = prevReduceBranch(bestFeatLabel, trainDataset,
                                       testDataset, labels)
        if splitErrors >= majorTest(major, testDataset):
            return major

    # Work on a new list instead of deleting from the caller's ``labels``.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    branches = {}
    for value in dict.fromkeys(example[bestFeat] for example in trainDataset):
        branches[value] = createTree(
            splitDataset(trainDataset, bestFeat, value), subLabels,
            datasetFull, labelsFull,
            splitDataset(testDataset, bestFeat, value), pruning)

    # Add a majority-class leaf for every value of the feature that occurs in
    # the full data but not at this node, so that no value is left without a
    # branch (e.g. a colour that is absent from the current subset).
    bestFeatIndex = labelsFull.index(bestFeatLabel)
    for example in datasetFull:
        branches.setdefault(example[bestFeatIndex], major)
    myTree = {bestFeatLabel: branches}

    # Post-pruning: keep the subtree unless replacing it by a leaf gives a
    # strictly lower test error.
    if pruning == 'post':
        keepErrors = postReduceBranch(myTree, testDataset, labels)
        if keepErrors > majorTest(major, testDataset):
            return major
    return myTree


if __name__ == '__main__':
    # The data file may be given as the first command-line argument; without
    # one the original hard-coded location is used.
    if len(sys.argv) > 1:
        filename = sys.argv[1]
    else:
        filename = 'C:\\Users\\14399\\Desktop\\西瓜2.0.csv'
    dataset, labels, trainDataset, testDataset = readDataset(filename)
    datasetFull = trainDataset[:]
    labelsFull = labels[:]
    myTree = createTree(trainDataset, labels, datasetFull, labelsFull, testDataset)
    print(myTree)
未剪枝:{'臍部': {'凹陷': {'色澤': {'淺白': '否', '青綠': '是', '烏黑': '是'}}, '稍凹': {'根蒂': {'蜷縮': '否', '稍蜷': {'色澤': {'青綠': '是', '烏黑': {'紋理': {'稍糊': '是', '清晰': '否', '模糊': '是'}}, '淺白': '是'}}, '硬挺': '是'}}, '平坦': '否'}}
預剪枝:{'臍部': {'稍凹': '是', '平坦': '否', '凹陷': '是'}}
後剪枝: {'臍部': {'稍凹': {'根蒂': {'蜷縮': '否', '稍蜷': {'色澤': {'烏黑': '是', '青綠': '是', '淺白': '是'}}, '硬挺': '是'}}, '凹陷': '是', '平坦': '否'}}
西瓜2.0資料集:連結:https://pan.baidu.com/s/12aVngexje2RdizgOg1Fr0A 提取碼:uywy