1. 程式人生 > >決策樹 ID3學習筆記

決策樹 ID3學習筆記

最近開始看機器學習方面的知識,決策樹(DT)主要有三種演算法 ID3(Iterative Dichotomiser 3)、 C4.5、CART哈。

包括的知識有,資訊熵、資訊增益、資訊增益率、基尼指數概念,另外還有預剪枝、後剪枝等。

決策樹演算法的優點:

1:易於理解,使用白盒模型,相比之下,在一個黑盒子模型(例如人工神經網路),結果可能更難以解釋
2:需要準備的資料量不大
3:能夠處理數字和資料的類別(需要做相應的轉變),而其他演算法分析的資料集往往是隻有一種型別的變數
4:能夠處理多輸出的問題


決策樹演算法的缺點:
1:決策樹演算法學習者可以建立複雜的樹,容易過擬合,為了避免這種問題,出現了剪枝的概念,即設定一個葉子結點所需要的最小數目或者設定樹的最大深度
2:決策樹的結果可能是不穩定的,因為在資料中一個很小的變化可能導致生成一個完全不同的樹,這個問題可以通過使用整合決策樹來解決
3:實際決策樹學習演算法是基於啟發式演算法,如貪婪演算法,尋求在每個節點上的區域性最優決策。這樣的演算法不能保證返回全域性最優決策樹。

4:決策樹學習者很可能在某些類占主導地位時建立有有偏異的樹,因此建議用平衡的資料訓練決策樹

 

剪枝:

預剪枝:在選擇分裂屬性的時候,計算增加該屬性,驗證集的準確率是否能夠提高,如果可以提高,則增加該屬性;否則不增加

後剪枝:按照演算法(c4.5、cart 等)生成決策樹,然後從下往上對每個分支屬性進行判斷,如果去掉該屬性,驗證集的準確率是否降低,如果不降低就保留,否則去掉。

一般情況下,預剪枝容易underfitting,而後剪枝不存在這種問題,但是後剪枝計算量通常會比預剪枝大。

 

ID3演算法最早是由羅斯昆(J. Ross Quinlan)於1975年在悉尼大學提出的一種分類預測演算法,演算法以資訊理論為基礎,其核心是“資訊熵”。ID3演算法通過計算每個屬性的資訊增益,認為資訊增益高的是好屬性,每次劃分選取資訊增益最高的屬性為劃分標準,重複這個過程,直至生成一個能完美分類訓練樣例的決策樹。首先介紹兩個概念:

理解了上面兩個概念就好辦了,下面是一個具體的例子,加入根據 outlook,temp,hum,windy 屬性來決定去不去paly。

直接上程式碼,寫的比較搓,就是個示例,哈:

import copy, numpy as np
import sys
import pdb

#pdb.set_trace()
'''訓練樣本,前兩列對應是屬性值,最後一列表示分類,下面的值分別對應
              Outlook, Temperature, Humidity, Windy, Play'''

train_data = [['Outlook', 'Temperature', 'Humidity', 'Windy', 'Play'],
              ['sunny', 'hot', 'high', 'false', 'no'],
              ['sunny', 'hot', 'high', 'true', 'no'],
              ['overcast', 'hot', 'high', 'false', 'yes'], 
              ['rain', 'mild', 'high', 'false', 'yes'], 
              ['rain', 'cool', 'normal', 'false', 'yes'],
              ['rain', 'cool', 'normal', 'true', 'no'],
              ['overcast', 'cool', 'normal', 'true', 'yes'],
              ['sunny', 'mild', 'high', 'false', 'no'],
              ['sunny', 'cool', 'normal', 'false', 'yes'],
              ['rain', 'mild', 'normal', 'false', 'yes'],
              ['sunny', 'mild', 'normal', 'true', 'yes'],
              ['overcast', 'mild', 'high', 'true', 'yes'],
              ['overcast', 'hot', 'normal', 'false', 'yes'],
              ['rain', 'mild', 'high', 'true', 'no']]
#測試樣本
test_data = [['sunny', 'mild', 'normal', 'false'], ['overcast', 'hot', 'normal', 'true']]

#計算資訊熵,傳入的是某個屬性下的分類及對應的個數,{1:2, 2:5, 3:9},返回資訊熵
def entropy_cal(pro):
	entro = 0
	total = sum(list(pro.values()))
	for key, value in pro.items():
		P = value/total
		entro = entro -P * np.log2(P)
	return entro

#res = entropy_cal({1:1, 2:2, 3:1})
#print(res)


#傳入的是訓練樣本,預設最後一列是標籤,返回資訊增益最大的屬性值
def IG(train_data):
	#print(train_data)
	#不存在屬性
	if len(train_data[0]) <= 1:
		return -1
	#訓練樣本集數量
	train_num = len(train_data)
	#print(train_num)
	#屬性個數
	pro_num = len(train_data[0])
    #宣告一個空的字典
	pro_dic = {}
	for i in range(pro_num):
		for j in range(1, train_num):
			pro_ind = str(i)

			if pro_ind not in pro_dic.keys():
				pro_dic[pro_ind] = {}
			
			if train_data[j][i] in pro_dic[pro_ind].keys():
				pro_dic[pro_ind][train_data[j][i]]['total'] += 1
			else:
				pro_dic[pro_ind][train_data[j][i]] = {'total':1}
			#記錄對應標籤
			if train_data[j][-1] in pro_dic[pro_ind][train_data[j][i]].keys():
				pro_dic[pro_ind][train_data[j][i]][train_data[j][-1]] += 1
			else:
				pro_dic[pro_ind][train_data[j][i]][train_data[j][-1]] = 1

	#print('--------')
	#print(pro_dic)
	#print('========')
	MAX_IG = 10000
	MAX_INDEX = -1
	for i in range(pro_num - 1):
		#對每個屬性計算熵
		linshi = 0
		for key, value in pro_dic[str(i)].items():
			P = value['total']/train_num
			del value['total']
			linshi += P * entropy_cal(value)
		#print(linshi)
		if linshi < MAX_IG:
			MAX_INDEX = i
			MAX_IG = linshi
	#print('選擇的屬性是:')
	#print(MAX_IG)
	#print(MAX_INDEX)
	return {'index':MAX_INDEX, 'pro':pro_dic[str(MAX_INDEX)].keys()}


		
def getTree(train, node):
	#print(tree)
	res=IG(train)
	#print(res)
	index = res['index']
	#遍歷列表train,如何 index 對應的標籤都是同一個,那麼就結束
	num = len(train)

	if index != -1:
		#pro = {'sunny':{'yes':3,'no':2, 'data':[]}, 'overcast':{'yes':4, 'data':[]}, 'rain':{'yes':3,'no':2, 'data':[]}}
		pro = {}
		for i in range(num):
			if train[i][index] in pro.keys():
				if train[i][-1] in pro[train[i][index]].keys():
					pro[train[i][index]][train[i][-1]] += 1
				else:
					pro[train[i][index]][train[i][-1]] = 1
			else:
				pro[train[i][index]] = {}
				pro[train[i][index]][train[i][-1]] = 1

			temp = copy.deepcopy(train[i])
			del temp[index]
			if 'data' not in pro[train[i][index]].keys():
				pro[train[i][index]]['data'] = list()
				head = copy.deepcopy(train[0])
				del head[index]
				pro[train[i][index]]['data'].append(head)
			pro[train[i][index]]['data'].append(temp)
	#print(pro)
	name = train[0][index]
	if len(node) == 0:
		#print(tree)
		for i in res['pro']:
			if len(pro[i].keys()) <= 2:
				key = list(pro[i].keys())
				if key[0] == 'data':
					del key[0]
				print([name, i, key[0]])
			else:
				newdata = copy.deepcopy(pro[i]['data'])
				getTree(newdata, [name, i])
	else:
		#print('-----------')
		#print(res)
		#print('===========')
		for i in res['pro']:
			newnode = copy.deepcopy(node)
			if len(pro[i].keys()) <= 2:
				key = list(pro[i].keys())
				if key[0] == 'data':
					del key[0]
				newnode.append(name)
				newnode.append(i)
				newnode.append(key[0])
				print(newnode)
			else:
				newdata = copy.deepcopy(pro[i]['data'])
				newnode.append(name)
				newnode.append(i)
				getTree(newdata, newnode)

def main():
	getTree(train_data, [])


if __name__ == "__main__":
    sys.exit(main())

執行結果: