決策樹-屬性選擇
阿新 • • 發佈:2021-11-18
現在,我們要做的是進行屬性(或者說特徵)的選擇
光看程式清單3-2,以及把陣列帶進去執行一遍可能也有點不清晰,最好先看一下西瓜書
然後意思是傳進去一個數據集,對於某一列(axis=0表示第1列),如果為0(value=0),那麼保留這一行但是不要這個屬性對應的值
import shannonEnt dataSet = shannonEnt.dataSet labelSet = shannonEnt.labelSet def splitDataSet(dataSet, axis, value): featureDataSet = [] for featureVec in dataSet: if featureVec[axis] == value: tempVec = featureVec[:axis] tempVec.extend(featureVec[axis + 1:]) featureDataSet.append(tempVec) return featureDataSet a = splitDataSet(dataSet, 0, 0) b = splitDataSet(dataSet, 0, 1) print(a) print(b) [[1, 'no'], [1, 'no']] [[1, 'yes'], [1, 'yes'], [0, 'no']]
這裡a是對第0列取值為0的行進行了處理,b是對第0列取值為1的行進行了處理
這裡想更簡單點的話,用pd去掉某一列,然後再算比例也可以
接著書上使用ID3進行屬性選擇
思路如下:
- 首先計算總的Ent,得到總共有2個屬性
- 然後對於2個屬性進行遍歷,對於第1個屬性,得到其對應的屬性取值為[1, 1, 1, 0, 0]
- 那麼對於剛剛得到的[1, 1, 1, 0, 0],我們知道有兩種取值,用set得到列表[1,0]
- 從這兩個取值中再去用剛剛寫好的splitDataSet函式得到1對應的子集,以及0對應的子集
- 這裡我們能知道1對應的子集個數為3,那麼由西瓜書公式4.2去計算$sigma$和Ent
結合起來程式碼如下
import shannonEnt dataSet = shannonEnt.dataSet labelSet = shannonEnt.labelSet def splitDataSet(dataSet, axis, value): featureDataSet = [] for featureVec in dataSet: if featureVec[axis] == value: tempVec = featureVec[:axis] tempVec.extend(featureVec[axis + 1:]) featureDataSet.append(tempVec) return featureDataSet # a = splitDataSet(dataSet, 0, 0) # b = splitDataSet(dataSet, 0, 1) # print(a) # print(b) def bestFeature(dataSet): # 獲得特徵(屬性)個數,這裡為2 featureNum = len(dataSet[0]) - 1 # 按西瓜書來看,計算Ent(D) totalEntropy = shannonEnt.calcShannonEnt(dataSet) bestInfoGain = 0.0 bestFeature = -1 for i in range(featureNum): # 匿名函式,[1, 1, 1, 0, 0],[1, 1, 0, 1, 1] # 即獲得每個屬性對應的列向量 featList = [example[i] for example in dataSet] # print(featList) # 知道每個屬性可能有的取值 uniqueVals = set(featList) # print(uniqueVals) newEntropy = 0.0 for value in uniqueVals: subDataSet = splitDataSet(dataSet,i,value) prob = len(subDataSet)/len(dataSet) newEntropy += prob*shannonEnt.calcShannonEnt(subDataSet) infoGain = totalEntropy - newEntropy if infoGain > bestInfoGain: bestInfoGain = infoGain bestFeature = i return bestFeature if __name__ == "__main__": result = bestFeature(dataSet) print(result)
0
這裡的內容,主要是要對ID3演算法比較熟,可以結合西瓜書多看幾遍