1. 程式人生 > 其它 >決策樹-屬性選擇

決策樹-屬性選擇

現在,我們要做的是進行屬性(或者說特徵)的選擇

光看程式清單3-2,以及把陣列帶進去執行一遍可能也有點不清晰,最好先看一下西瓜書

然後意思是傳進去一個數據集,對於某一列(axis=0表示第1列),如果為0(value=0),那麼保留這一行但是不要這個屬性對應的值

import shannonEnt

dataSet = shannonEnt.dataSet
labelSet = shannonEnt.labelSet


def splitDataSet(dataSet, axis, value):
    featureDataSet = []
    for featureVec in dataSet:
        if featureVec[axis] == value:
            tempVec = featureVec[:axis]
            tempVec.extend(featureVec[axis + 1:])
            featureDataSet.append(tempVec)
    return featureDataSet


a = splitDataSet(dataSet, 0, 0)
b = splitDataSet(dataSet, 0, 1)
print(a)
print(b)

[[1, 'no'], [1, 'no']]
[[1, 'yes'], [1, 'yes'], [0, 'no']]

這裡a是對第0列取值為0的行進行了處理,b是對第0列取值為1的行進行了處理

這裡想更簡單點的話,用pd去掉某一列,然後再算比例也可以

接著書上使用ID3進行屬性選擇

思路如下:

  • 首先計算總的Ent,得到總共有2個屬性
  • 然後對於2個屬性進行遍歷,對於第1個屬性,得到其對應的屬性取值為[1, 1, 1, 0, 0]
  • 那麼對於剛剛得到的[1, 1, 1, 0, 0],我們知道有兩種取值,用set得到列表[1,0]
  • 從這兩個取值中再去用剛剛寫好的splitDataSet函式得到1對應的子集,以及0對應的子集
  • 這裡我們能知道1對應的子集個數為3,那麼由西瓜書公式4.2去計算$sigma$和Ent

結合起來程式碼如下

import shannonEnt

dataSet = shannonEnt.dataSet
labelSet = shannonEnt.labelSet


def splitDataSet(dataSet, axis, value):
    featureDataSet = []
    for featureVec in dataSet:
        if featureVec[axis] == value:
            tempVec = featureVec[:axis]
            tempVec.extend(featureVec[axis + 1:])
            featureDataSet.append(tempVec)
    return featureDataSet


# a = splitDataSet(dataSet, 0, 0)
# b = splitDataSet(dataSet, 0, 1)
# print(a)
# print(b)


def bestFeature(dataSet):
    # 獲得特徵(屬性)個數,這裡為2
    featureNum = len(dataSet[0]) - 1
    # 按西瓜書來看,計算Ent(D)
    totalEntropy = shannonEnt.calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(featureNum):
        # 匿名函式,[1, 1, 1, 0, 0],[1, 1, 0, 1, 1]
        # 即獲得每個屬性對應的列向量
        featList = [example[i] for example in dataSet]
        # print(featList)
        # 知道每個屬性可能有的取值
        uniqueVals = set(featList)
        # print(uniqueVals)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet,i,value)
            prob = len(subDataSet)/len(dataSet)
            newEntropy += prob*shannonEnt.calcShannonEnt(subDataSet)
        infoGain = totalEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i

    return bestFeature

if __name__ == "__main__":
    result = bestFeature(dataSet)
    print(result)
0

這裡的內容,主要是要對ID3演算法比較熟,可以結合西瓜書多看幾遍