1. 程式人生 > 實用技巧 >西瓜書4.3 編寫過程 決策樹

西瓜書4.3 編寫過程 決策樹

既然要建樹,就可以考慮使用自定義類。

然後考慮這個類需要哪些功能。首先,需要有根、葉的區分標記,然後是節點的屬性標記,

以及節點的屬性集,節點的葉子節點的記錄,節點的生成葉節點函式,節點的最優劃分屬性選擇函式,

節點包含的哪些訓練集樣本的記錄。

暫時就想到這些,那麼就先開始。

首先為了方便表示還是先把資料集全部用數字表示,變換規則依然是:

色澤:淺白 1,青綠 2,烏黑 3

根蒂:蜷縮 1,稍蜷 2,硬挺 3

敲聲:沉悶 1,濁響 2,清脆 3

紋理:模糊 1,稍糊 2,清晰 3

臍部:平坦 1,稍凹 2,凹陷 3

觸感:硬滑 1,軟粘 0

D=np.array([[2,1,2,3,3,1,0.697,0.406,1],\
   [
3,1,1,3,3,1,0.774,0.376,1],\ [3,1,2,3,3,1,0.634,0.264,1],\ [2,1,1,3,3,1,0.608,0.318,1],\ [1,1,2,3,3,1,0.556,0.215,1],\ [2,2,2,3,2,0,0.403,0.237,1],\ [3,2,2,2,2,0,0.481,0.149,1],\ [3,2,2,3,2,1,0.437,0.211,1],\ [3,2,1,2,2,1,0.666,0.091,0],\ [2,3,3,3,1,0,0.243,0.267,0],\ [1,3,3,1,1,1,0.245,0.057,0],\ [
1,1,2,1,1,0,0.343,0.099,0],\ [2,2,2,2,3,1,0.639,0.161,0],\ [1,2,1,2,3,1,0.657,0.198,0],\ [3,2,2,3,2,0,0.360,0.370,0],\ [1,1,2,1,1,1,0.593,0.042,0],\ [2,1,1,2,2,1,0.719,0.103,0]]) Y=['色澤','根蒂','敲聲','紋理','臍部','觸感','密度','含糖率','好瓜與否']

然後增加一個num陣列,用於表示離散屬性每種的屬性個數:

num=[3,3,3,3,3,2]
順便新增一個tree空列表用於儲存節點。
然後開始編寫類node(節點)

首先函式頭+初始化函式,將剛剛提到的屬性值都先定義好或者賦值:
def __init__(self,root,Dd,Aa,father,nu,floor):
self.root=root#樹根為0,中間為1,葉子為2
self.num=nu#在樹列表中的序號
self.attribu=None#葉節點屬性
self.danum=0#若為連續屬性劃分時的分類值
self.attri_set=Aa#屬性集
self.leaf=[]#子節點
self.sample=Dd#樣本集
self.father=father#父節點
self.gain_set=[]#各屬性的資訊熵
self.floor=floor#節點的層級
return


實際編寫的時候還是增加了一些屬性或者屬性集。所有屬性集、結點集或樣本集記錄的都是在對應全集(屬性集Y、結點集tree或樣本集D)中的序號。

然後就參考西瓜書上的圖4.2基本演算法,按步驟編寫成員函式。

首先就是如果該節點中所有樣本都屬於同一類別C(都是好瓜或者都是壞瓜),那就將該節點標記為葉子節點,節點型別標記為C:

def all_label_same(self):
k=D[self.sample[0]][8]
mark=1
for i in self.sample:
if(D[self.sample[i]][8]!=k):
mark=0
break
if(mark==1):
self.attribu = 8
self.danum=k
self.root = 2
return mark

用mark記錄比較結果,全同為1反之為0,若mark變0則跳出迴圈,若mark為1則修改節點屬性。

然後是第二步判斷是否節點屬性集為空或是節點的樣本的屬性完全一致,如果是完全一致則節點屬性值取節點樣本中較多的一類,

如果為空則與父節點屬性值相同,兩種情況均將該節點標記為葉子節點:

def all_b_att_same(self):#全空或全同
mark=0
if(len(self.attri_set)==0):
mark=1
self.danum=tree[self.father].coun_labe
self.attribu=8
self.root=2
elif(self.all_same()):
mark=1
self.danum=self.coun_labe()
self.attribu=8
self.root=2
return mark

同樣地使用mark標記。全空的話類別標記為父節點中樣本最多的標記。為了模組化和程式碼的可觀性,將判斷節點樣本是否屬性值完全一致與求樣本集中較多的標記類寫成了成員函式,詳細如下:

def all_same(self):
k=D[self.sample[0]]
mark=1
for i in range(len(self.sample)):
if(k!=D[self.sample[i]]):
mark=0
return mark

def coun_labe(self):#計算樣本中數量較多的類別
s=0
mark=0
for i in range(len(self.sample)):
s=s+D[self.sample[i]][8]
if(s>=(len(self.sample)/2)):
mark=1
return mark

因為樣本標記是使用0/1標記的,所以統計時只需要計算標記的值之和然後與總樣本數對比就可以了。

然後就是選擇最優劃分屬性的函式,由於屬性集中既有離散屬性又有連續屬性,所以要分開討論計算,

先寫出主體函式:

def best_divide(self):
ents=self.ent(self.sample)
gain=[]
for i in range(len(self.attri_set)):
order=self.attri_set[i]#計算第order個屬性的資訊增益
gainp = ents
if(order<=5):
for j in range(num[order]):
if(order!=5):
gainp=gainp-self.gain(order,j+1)
else:
gainp=gainp-self.gain(order,j)
gain.append([gainp,order,0])
else:
mx=self.maxgain(order)
gainp=gainp-mx[0]
gain.append([gainp,order,mx[1]])
maxx=0
for i in range(len(self.attri_set)):
if(gain[i][0]>gain[maxx][0]):
maxx=i
self.root=1
self.attribu=gain[maxx][1]
self.danum=gain[maxx][2]
return gain[maxx]

ent函式用於計算該節點的資訊熵,然後計算各個屬性的資訊增益,離散屬性和連續屬性分開討論並編寫不同函式計算,

由於連續屬性劃分之後還需要返回一個劃分標準(小於等於x),所以需要返回一個列表,同時為了方便,離散屬性也用相同形式儲存,

其中儲存x的地方放一個零,將這樣形式的包含最大資訊增益、對應屬性以及劃分標準(若為連續屬性)作為函式的返回值。

以下為ent函式的內容:

def ent(self,d):
s=0
for i in range(len(d)):
s=s+D[d[i]][8]
if(len(d)==0):
k=0
else:
k=s/len(d)
if(k==1 or k==0):
s=-math.log(1,2)
else:
s=-k*math.log(k,2)-(1-k)*math.log(1-k,2)
return s

簡單的遍歷計算,考慮到約定0log2 0=0,分類討論一下,下面都這麼做。然後是計算離散屬性資訊增益的函式:

def gain(self,i,j):
s=0
k=0
for p in range(len(self.sample)):
if(D[self.sample[p]][i]==j):
s+=1
if(D[self.sample[p]][8]==1):
k+=1
if(s==0):
k=0
else:
k=k/s
if(k==0 or k==1):
k=-math.log(1,2)
else:
k=-k*math.log(k,2)-(1-k)*math.log(1-k,2)
if(len(self.sample)==0):
k=0
else:
k=k*s/len(self.sample)
return k

在for迴圈中s用於計數節點樣本集中屬性i的值為j的個數,k用於計數滿足j條件中類別為1的個數,利用這兩個指就可以分別求出書上4.2中減號後

每一項的值。接下來是計算離散屬性的資訊增益的函式:

   def maxgain(self,a):
      lis=[]
      for i in range(self.sample):
         lis.append(D[self.sample[i]][a])
      lis.sort()
      long=len(lis)
      en=[]
      for i in range(long-1):
         t=(lis[i]+lis[i+1])/2
         entp=0
         num1=0#屬性a的值小於等於t的樣本中標記值為1的個數
         num2=0#屬性a的值大於t的樣本中標記值為1的個數
         sum=0#屬性a的值小於等於t的樣本數
         for j in range(self.sample):
            order=self.sample[j]
            if(D[order][a]<=t):
               sum+=1
               if(D[order][8]==1):
                  num1+=1
            else:
               if(D[order][8]==1):
                  num2+=1
         num1/=sum
         num2/=(long-sum)
         if(num1==0 or num1==1):
            entp-=np.log2(1)
         else:
            entp=entp-num1*np.log2(num1)-(1-num1)*np.log2(1-num1)
         if(num2==0 or num2==1):
            entp-=np.log2(1)
         else:
            entp=entp-num2*np.log2(num2)-(1-num2)*np.log2(1-num2)
         en.append(entp)
      s=0
      for i in range(en):
         if(en[i]<en[s]):
            s=i
      result=[en[s],(lis[s]+lis[s+1])/2]
      return result

先取出樣本中連續屬性的值,放入一個list中,sort之後迴圈依次相鄰數兩兩取平均值t作為劃分依據,然後遍歷,記錄屬性a的值小於等於t的樣本中標記值為1的個數、屬性a的值大於t的樣本中標記值為1的個數以及屬性a的值小於等於t的樣本數。依據這三項就可以計算每一種劃分的資訊增益大小,從結果中選出最小值返回。

然後就完成了最優屬性選擇,接下來就是建立子節點。

   def build_son(self):
      mark=0
      if(not self.all_label_same()):
         mark=0
      elif(not self.all_b_att_same()):
         mark=0
      else:
         mark=1
         bd=self.best_divide()
         if(bd[2]==0):
            pass_a=[]
            for i in range(self.attri_set):
               if(self.attri_set[i]!=bd[2]):
                  pass_a.append(self.attri_set[i])
            for i in range(num[bd[1]]):
               k=i
               pass_d = []
               if(bd[1]!=5):
                  k+=1
               for j in range(self.sample):
                  if(D[self.sample[j]][bd[1]]==k):
                     pass_d.append(self.sample[j])
               lon=len(tree)
               self.leaf.append([lon,k])
               tree.append(node(1,pass_d,pass_a,self.num,lon))
         else:
            pass_small=[]
            pass_bigger=[]
            for i in range(self.sample):
               if(D[self.sample[i]][bd[1]]<=bd[2]):
                  pass_small.append(self.sample[i])
               else:
                  pass_bigger.append(self.sample[i])
            lon=len(tree)
            self.leaf.append([lon,0])
            self.leaf.append([lon+1,1])
            tree.append(node(1,pass_small,self.attri_set,self.num,lon))
            tree.append(node(1,pass_bigger,self.attri_set,self.num,lon+1))
         #for i in range(self.leaf):
          #  tree[self.leaf[i]].build_son()
      return mark

將葉節點判斷內容也整合到子節點建立中,依然使用mark標誌以備不時之需。如果不是葉節點的話,就進行建立子節點操作,依照best_divide函式返回的值進行分類討論,需要注意的是第五項屬性的值是0或1,其他離散屬性的值是1或2或3,在迴圈對比屬性值時要根據情況改變值,觀察初始化函式:

def __init__(self,root,Dd,Aa,father,nu,floor):
self.root=root#樹根為0,中間為1,葉子為2
self.num=nu#在樹列表中的序號
self.attribu=8#葉節點屬性
self.danum=0#若為連續屬性劃分時的分類值
self.attri_set=Aa#屬性集
self.leaf=[]#子節點
self.sample=Dd#樣本集
self.father=father#父節點
self.gain_set=[]#各屬性的資訊熵
self.floor=floor
return

可以看到我們需要傳入的資料有節點的根屬性(樹根、中間或葉子,預設都為1),節點屬性集,節點樣本集和父節點的序號以及節點本身序號,其中在屬性集和樣本集的統計方法上離散屬性和連續屬性不一樣,要分開進行,父節點序號父節點中有記錄,子節點序號可以直接讀取當前tree的長度獲得,那麼主要的步驟就是屬性集和樣本集的構建。

首先是離散屬性,其屬性集只要父節點的屬性集排除掉自己的篩選屬性即可,而樣本集前四個屬性都是有三種取值,那麼對使值分別去0、1、2分別+1之後就可以得到屬性取值,篩選完一種取值後得到本種取值的樣本集就可以直接生成子節點,並將(子節點序號,子節點取值)資料對存入本節點的子節點集,第五個屬性只要省去屬性值+1的步驟即可。

然後是連續屬性,子節點的屬性集與父節點相同,而樣本集則根據best_divide返回的劃分值進行劃分,直接在一次遍歷中完成兩個樣本集的錄入然後一次性計入節點的子節點集。

在所有的子節點生成完畢後對每個子節點進行節點生成。

然後考慮決策樹分類完成之後的列印問題,為了使打印出來的樹具有較好的可讀性,在每個節點中加入層的屬性。

子節點的層數是父節點層數+1。

然後就是節點的列印函式:

def prin(self):
print('root:',self.root,' 劃分屬性(節點屬性):',Y[self.attribu],' ',self.danum,' floor:',self.floor)

主函式如下:

star=node(0,D,[0,1,2,3,4,5,6,7,8],0,0,0)
tree.append(star)
star.build_son()
flm=tree[len(tree)-1].danum+1
for i in range(flm):
   for j in range(tree):
      if(tree[j].danum==i):
         tree[j].prin()
   print('\n')

然後debug把問題都改掉之後發現遞迴爆棧了,所以把build函式裡的遞迴部分註釋掉,主函式更換如下:

star=node(0,[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16],[0,1,2,3,4,5,6,7,8],0,0,0)
tree.append(star)
star.build_son()
tail=0

while(tail!=(len(tree)-1)):
   las=len(tree)-1
   k=tail
   for i in range(k+1,las+1):
      tree[i].prin()
      tree[i].build_son()

flm=tree[len(tree)-1].danum+1
for i in range(flm):
   for j in range(len(tree)):
      if(tree[j].danum==i):
         tree[j].prin()
   print('\n')

先給原始節點建立子節點以後,標記原始節點下一位和tree陣列最後一位,不斷給新增加的節點生成子節點直到沒有子節點生成,列印輸出

粗略估計了一下輸出可能會有14w以上的節點,而且也沒有做視覺化,原理上講程式應該基本實現了功能,但還是有待以後的完善,時間有限,暫且就這樣了。