周志華《機器學習》習題4.3
為表4.3中資料生成一棵決策樹。
程式碼是在《機器學習實戰》的程式碼基礎上改良的,借用了numpy, pandas之後明顯簡化了程式碼。表4.3的資料特徵是離散屬性和連續屬性都有,問題就複雜在這裡。話不多說,看程式碼。
先定義幾個輔助函式,正常的思路是先想巨集觀演算法,然後需要什麼函式就定義什麼函式。
import math
import pandas as pd
import numpy as np
from treePlotter import createPlot
def entropy(data):
label_values = data[data.columns[-1 ]]
#Returns object containing counts of unique values.
counts = label_values.value_counts()
s = 0
for c in label_values.unique():
freq = float(counts[c])/len(label_values)
s -= freq*math.log(freq,2)
return s
def is_continuous(data,attr):
"""Check if attr is a continuous attribute"""
return data[attr].dtype == 'float64'
def split_points(data,attr):
"""Returns Ta,Equation(4.7),p.84"""
values = np.sort(data[attr].values)
return [(x+y)/2 for x,y in zip(values[:-1],values[1:])]
treePlotter是《實戰》裡的模組,用來把決策樹畫出來。這裡決策樹是用字典表示的,key可以表示樹的節點或分枝,表示節點的時候是屬性,表示分枝的時候是屬性值。value又是一個字典或字串,是字串的時候表示葉,也就是標記。這裡的data是pandas裡的DataFrame,形式上像一個表,對錶的常見操作它都可以方便的解決。命名習慣跟書上一致。
再繼續看怎麼計算資訊增益:
def discrete_gain(data,attr):
V = data[attr].unique()
s = 0
for v in V:
data_v = data[data[attr]== v]
s += float(len(data_v))/len(data)*entropy(data_v)
return (entropy(data) - s,None)
def continuous_gain(data,attr,points):
"""Equation(4.8),p.84,returns the max gain along with its splitting point"""
entD = entropy(data)
#gains is a list of pairs of the form (gain,t)
gains = []
for t in points:
d_plus = data[data[attr] > t]
d_minus = data[data[attr] <= t]
gain = entD - (float(len(d_plus))/len(data)*entropy(d_plus)+float(len(d_minus))/len(data)*entropy(d_minus))
gains.append((gain,t))
return max(gains)
離散屬性的資訊增益一目瞭然,最後返回的pair中的None是為了給後面的函式判斷之用,看到None就知道是離散屬性了。連續屬性的資訊增益的計算方法是對每個劃分點
然後就是統管的資訊增益函式:
def gain(data,attr):
if is_continuous(data,attr):
points = split_points(data,attr)
return continuous_gain(data,attr,points)
else:
return discrete_gain(data,attr)
還要用到一個眾數函式:
def majority(label_values):
counts = label_values.value_counts()
return counts.index[0]
我們的id3終於登場了:
def id3(data):
attrs = data.columns[:-1]
#attrWithGain is of the form [(attr,(gain,t))], t is None if attr is discrete
attrWithGain = [(a,gain(data,a)) for a in attrs]
attrWithGain.sort(key = lambda tup:tup[1][0],reverse = True)
return attrWithGain[0]
它對每個屬性都計算了資訊增益,最後返回資訊增益最大的那個屬性,連帶兩個附加值,形式是(attr,(gain,t))。
最後造樹:
def createTree(data,split_function):
label = data.columns[-1]
label_values = data[label]
#Stop when all classes are equal
if len(label_values.unique()) == 1:
return label_values.values[0]
#When no more features, or only one feature with same values, return majority
if data.shape[1] == 1 or (data.shape[1]==2 and len(data.T.ix[0].unique())==1):
return majority(label_values)
bestAttr,(g,t) = split_function(data)
#If bestAttr is discrete
if t is None:
#In this tree,a key is a node, the value is a list of trees,also a dictionary
myTree = {bestAttr:{}}
values = data[bestAttr].unique()
for v in values:
data_v = data[data[bestAttr]== v]
attrsAndLabel = data.columns.tolist()
attrsAndLabel.remove(bestAttr)
data_v = data_v[attrsAndLabel]
myTree[bestAttr][v] = createTree(data_v,split_function)
return myTree
#If bestAttr is continuous
else:
t = round(t,3)
node = bestAttr+'<='+str(t)
myTree = {node:{}}
values = ['yes','no']
for v in values:
data_v = data[data[bestAttr] <= t] if v == 'yes' else data[data[bestAttr] > t]
myTree[node][v] = createTree(data_v,split_function)
return myTree
這個我就不細說了,還得自己看。值得一提的是離散屬性的下一次遞迴把當前的離散值刪掉了,attrsAndLabel.remove(bestAttr),因為不允許這個屬性出現在後續的分枝中。然而連續屬性的時候,不刪,允許繼續出現。這個好理解,畢竟對連續屬性用的是二分法,可能需要多個二分才能把情況搞清。
測試一下:
if __name__ == "__main__":
f = pd.read_csv(filepath_or_buffer = 'dataset/watermelon3.0en.csv', sep = ',')
data = f[f.columns[1:]]
tree = createTree(data,id3)
print tree
createPlot(tree)
我把原表翻譯成英文了,因為中文的列印字典不顯示漢字,畫圖的時候甚至直接不能畫。
id,color,root,knock,texture,umbilical,touch,density,sugar content,good melon
1,green,curled up,cloudy,clear,concave,hard slip,0.697,0.46,yes
2,black,curled up,dull,clear,concave,hard slip,0.774,0.376,yes
3,black,curled up,cloudy,clear,concave,hard slip,0.634,0.264,yes
4,green,curled up,dull,clear,concave,hard slip,0.608,0.318,yes
5,pale,curled up,cloudy,clear,concave,hard slip,0.556,0.215,yes
6,green,slightly curled,cloudy,clear,slightly concave,soft sticky,0.403,0.237,yes
7,black,slightly curled,cloudy,slightly paste,slightly concave,soft sticky,0.481,0.149,yes
8,black,slightly curled,cloudy,clear,slightly concave,hard slip,0.437,0.211,yes
9,black,slightly curled,dull,slightly paste,slightly concave,hard slip,0.666,0.091,no
10,green,stiff,crisp,clear,flat,soft sticky,0.243,0.267,no
11,pale,stiff,crisp,fuzzy,flat,hard slip,0.245,0.057,no
12,pale,curled up,cloudy,fuzzy,flat,soft sticky,0.343,0.099,no
13,green,slightly curled,cloudy,slightly paste,concave,hard slip,0.639,0.161,no
14,pale,slightly curled,dull,slightly paste,concave,hard slip,0.657,0.198,no
15,black,slightly curled,cloudy,clear,slightly concave,soft sticky,0.36,0.37,no
16,pale,curled up,cloudy,fuzzy,flat,hard slip,0.593,0.042,no
17,green,curled up,dull,slightly paste,slightly concave,hard slip,0.719,0.103,no
treePlotter我就不放上來了,委屈大家看一下字典湊合下吧。
畫出來的樹跟書上圖4.8一樣:
把程式碼按照順序複製到編輯器,儲存下就可以運行了,記得吧treePlotter註釋掉。