決策樹 ID3學習筆記
最近開始看機器學習方面的知識,決策樹(DT)主要有三種演算法 ID3(Iterative Dichotomiser 3)、 C4.5、CART哈。
包括的知識有,資訊熵、資訊增益、資訊增益率、基尼指數概念,另外還有預剪枝、後剪枝等。
決策樹演算法的優點:
1:易於理解,使用白盒模型,相比之下,在一個黑盒子模型(例如人工神經網路),結果可能更難以解釋
2:需要準備的資料量不大
3:能夠處理數字和資料的類別(需要做相應的轉變),而其他演算法分析的資料集往往是隻有一種型別的變數
4:能夠處理多輸出的問題
決策樹演算法的缺點:
1:決策樹演算法學習者可以建立複雜的樹,容易過擬合,為了避免這種問題,出現了剪枝的概念,即設定一個葉子結點所需要的最小數目或者設定樹的最大深度
2:決策樹的結果可能是不穩定的,因為在資料中一個很小的變化可能導致生成一個完全不同的樹,這個問題可以通過使用整合決策樹來解決
3:實際決策樹學習演算法是基於啟發式演算法,如貪婪演算法,尋求在每個節點上的區域性最優決策。這樣的演算法不能保證返回全域性最優決策樹。
4:決策樹學習者很可能在某些類占主導地位時建立有有偏異的樹,因此建議用平衡的資料訓練決策樹
剪枝:
預剪枝:在選擇分裂屬性的時候,計算增加該屬性,驗證集的準確率是否能夠提高,如果可以提高,則增加該屬性;否則不增加
後剪枝:按照演算法(c4.5、cart 等)生成決策樹,然後從下往上對每個分支屬性進行判斷,如果去掉該屬性,驗證集的準確率是否降低,如果不降低就保留,否則去掉。
一般情況下,預剪枝容易underfitting,而後剪枝不存在這種問題,但是後剪枝計算量通常會比預剪枝大。
ID3演算法最早是由羅斯昆(J. Ross Quinlan)於1975年在悉尼大學提出的一種分類預測演算法,演算法以資訊理論為基礎,其核心是“資訊熵”。ID3演算法通過計算每個屬性的資訊增益,認為資訊增益高的是好屬性,每次劃分選取資訊增益最高的屬性為劃分標準,重複這個過程,直至生成一個能完美分類訓練樣例的決策樹。首先介紹兩個概念:
理解了上面兩個概念就好辦了,下面是一個具體的例子,加入根據 outlook,temp,hum,windy 屬性來決定去不去paly。
直接上程式碼,寫的比較搓,就是個示例,哈:
import copy, numpy as np import sys import pdb #pdb.set_trace() '''訓練樣本,前兩列對應是屬性值,最後一列表示分類,下面的值分別對應 Outlook, Temperature, Humidity, Windy, Play''' train_data = [['Outlook', 'Temperature', 'Humidity', 'Windy', 'Play'], ['sunny', 'hot', 'high', 'false', 'no'], ['sunny', 'hot', 'high', 'true', 'no'], ['overcast', 'hot', 'high', 'false', 'yes'], ['rain', 'mild', 'high', 'false', 'yes'], ['rain', 'cool', 'normal', 'false', 'yes'], ['rain', 'cool', 'normal', 'true', 'no'], ['overcast', 'cool', 'normal', 'true', 'yes'], ['sunny', 'mild', 'high', 'false', 'no'], ['sunny', 'cool', 'normal', 'false', 'yes'], ['rain', 'mild', 'normal', 'false', 'yes'], ['sunny', 'mild', 'normal', 'true', 'yes'], ['overcast', 'mild', 'high', 'true', 'yes'], ['overcast', 'hot', 'normal', 'false', 'yes'], ['rain', 'mild', 'high', 'true', 'no']] #測試樣本 test_data = [['sunny', 'mild', 'normal', 'false'], ['overcast', 'hot', 'normal', 'true']] #計算資訊熵,傳入的是某個屬性下的分類及對應的個數,{1:2, 2:5, 3:9},返回資訊熵 def entropy_cal(pro): entro = 0 total = sum(list(pro.values())) for key, value in pro.items(): P = value/total entro = entro -P * np.log2(P) return entro #res = entropy_cal({1:1, 2:2, 3:1}) #print(res) #傳入的是訓練樣本,預設最後一列是標籤,返回資訊增益最大的屬性值 def IG(train_data): #print(train_data) #不存在屬性 if len(train_data[0]) <= 1: return -1 #訓練樣本集數量 train_num = len(train_data) #print(train_num) #屬性個數 pro_num = len(train_data[0]) #宣告一個空的字典 pro_dic = {} for i in range(pro_num): for j in range(1, train_num): pro_ind = str(i) if pro_ind not in pro_dic.keys(): pro_dic[pro_ind] = {} if train_data[j][i] in pro_dic[pro_ind].keys(): pro_dic[pro_ind][train_data[j][i]]['total'] += 1 else: pro_dic[pro_ind][train_data[j][i]] = {'total':1} #記錄對應標籤 if train_data[j][-1] in pro_dic[pro_ind][train_data[j][i]].keys(): pro_dic[pro_ind][train_data[j][i]][train_data[j][-1]] += 1 else: pro_dic[pro_ind][train_data[j][i]][train_data[j][-1]] = 1 #print('--------') #print(pro_dic) #print('========') MAX_IG = 10000 MAX_INDEX = -1 for i in range(pro_num - 1): #對每個屬性計算熵 linshi = 0 for key, value in pro_dic[str(i)].items(): P = value['total']/train_num del value['total'] linshi += P * entropy_cal(value) #print(linshi) if linshi < MAX_IG: MAX_INDEX = i MAX_IG = linshi #print('選擇的屬性是:') #print(MAX_IG) #print(MAX_INDEX) return {'index':MAX_INDEX, 'pro':pro_dic[str(MAX_INDEX)].keys()} def getTree(train, node): #print(tree) res=IG(train) #print(res) index = res['index'] #遍歷列表train,如何 index 對應的標籤都是同一個,那麼就結束 num = len(train) if index != -1: #pro = {'sunny':{'yes':3,'no':2, 'data':[]}, 'overcast':{'yes':4, 'data':[]}, 'rain':{'yes':3,'no':2, 'data':[]}} pro = {} for i in range(num): if train[i][index] in pro.keys(): if train[i][-1] in pro[train[i][index]].keys(): pro[train[i][index]][train[i][-1]] += 1 else: pro[train[i][index]][train[i][-1]] = 1 else: pro[train[i][index]] = {} pro[train[i][index]][train[i][-1]] = 1 temp = copy.deepcopy(train[i]) del temp[index] if 'data' not in pro[train[i][index]].keys(): pro[train[i][index]]['data'] = list() head = copy.deepcopy(train[0]) del head[index] pro[train[i][index]]['data'].append(head) pro[train[i][index]]['data'].append(temp) #print(pro) name = train[0][index] if len(node) == 0: #print(tree) for i in res['pro']: if len(pro[i].keys()) <= 2: key = list(pro[i].keys()) if key[0] == 'data': del key[0] print([name, i, key[0]]) else: newdata = copy.deepcopy(pro[i]['data']) getTree(newdata, [name, i]) else: #print('-----------') #print(res) #print('===========') for i in res['pro']: newnode = copy.deepcopy(node) if len(pro[i].keys()) <= 2: key = list(pro[i].keys()) if key[0] == 'data': del key[0] newnode.append(name) newnode.append(i) newnode.append(key[0]) print(newnode) else: newdata = copy.deepcopy(pro[i]['data']) newnode.append(name) newnode.append(i) getTree(newdata, newnode) def main(): getTree(train_data, []) if __name__ == "__main__": sys.exit(main())
執行結果: