1. 程式人生 > >機器學習實戰決策樹(一)——資訊增益與劃分資料集

機器學習實戰決策樹(一)——資訊增益與劃分資料集

from math import log
#計算給定的熵
def calcsahnnonent(dataset):
        numentries = len(dataset)     #計算例項的總數
        labelcounts ={}
        #建立一個數據字典
        for featvec in dataset:
            currentlabel = featvec[-1]    #鍵值是最後一列數值                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       
            if currentlabel not in labelcounts.keys():  #為所有可能的分類建立字典。使用的是字典中key()方法
                labelcounts[currentlabel]= 0
            labelcounts[currentlabel] += 1
        shannonent = 0.0
        for key in labelcounts:
            prob = float(labelcounts[key]) / numentries
            shannonent -= prob * log(prob,2)    #以2為底求對數
        return shannonent
#建立資料集
def createdataset():
    dataset = [ [1, 1, 'yes'],
                [1, 1, 'yes'],
                [1, 0, 'no'],
                [0, 1, 'no'],
                [0, 1, 'no'] ]
    labels = ['no surfacing', 'flippers']
    return dataset, labels

#mydata,labels = createdataset()
#print('mydata:',mydata)
#mydata[0][-1] = 'maybe'
#calcsahnnonent(mydata)

#按照給定特徵劃分資料集
def splitdatset(dataset,axis,value):      #dataset:需要劃分的資料集;axis:劃分資料集的特徵
    retdataset = []             #為了不修改原資料集,建立一個新列表物件
    for featvec in dataset:  #抽取資料
        if featvec[axis] == value:
            reducefeatvec = featvec[:axis]  #去掉axis 的特徵
            #將符合條件的新增到返回的資料集
            reducefeatvec.extend(featvec[axis+1:]) 
            retdataset.append(reducefeatvec)
    return retdataset
#print(splitdatset(mydata,0,1))


#選擇最好的資料集劃分方式
def choosebestfeaturetosplit(dataset):
    numfeatures = len(dataset[0]) - 1
    baseentropy = calcsahnnonent(dataset)  #計算整個資料集的原始夏農熵
    bestinfogain = 0.0   #初始化返回引數
    bestfeature = -1  #最優特徵索引
    for i in range(numfeatures):   #遍歷所有的特徵
        featlist = [example[i] for example in dataset]   # 獲取dataSet的第i個所有特徵
        uniquevals = set(featlist)   #建立set集合{},元素不可重複
        newentropy = 0.0 #經驗條件熵初始化為0 
        #計算資訊增益
        for value in uniquevals:  #遍歷該特徵的所有取值
            subdataset = splitdatset(dataset, i, value)
            prob = len(subdataset) / float(len(dataset))
            newentropy += prob * calcsahnnonent(subdataset)
        infogain = baseentropy - newentropy
        #尋找最大資訊增益
        if (infogain > bestinfogain):
            bestinfogain = infogain
            bestfeature = i
    return bestfeature

choosebestfeaturetosplit(mydata)