決策樹理解與python實現
程式碼實現請直接移步博文末尾
在機器學習領域,決策樹是用於資料分類、預測的模型。決策樹演算法通過分析訓練集的各個資料特徵的不同,由原始資料集構造出一個樹形結構,比如我們分析一封郵件是否為垃圾郵件時,可以根據傳送方域名、郵件主題等方式區分郵件是否為垃圾郵件,新資料通過使用構造出的決策樹模型來進行預測。
決策樹演算法的關鍵主要是尋找一個最合適的資料特徵將資料集區分開來。我使用以下資料進行的測試:
使用以下資料區分西瓜是否為好瓜
編號 | 色澤 | 根蒂 | 敲聲 | 紋理 | 臍部 | 觸感 | 好瓜 |
---|---|---|---|---|---|---|---|
1 | 青綠 | 蜷縮 | 濁響 | 清晰 | 凹陷 | 硬滑 | 是 |
2 | 烏黑 | 蜷縮 | 沉悶 | 清晰 | 凹陷 | 硬滑 | 是 |
ps: 以上為用到的部分資料
決策樹劃分資料集的方法虛擬碼:
function split_data()
if 所有資料都屬於同一個分類:
return 分類名稱
else:
尋找劃分資料集的最好特徵
劃分資料集
建立分支節點
for 每個劃分的子集:
呼叫函式split_data(),並將返回結果增加到分支節點中
return 分支節點
上述示例資料集中,是否為好瓜就是分類名稱,色澤、根蒂則是用於劃分資料集的特徵。
所以決策樹的基本思想就是不斷遞迴尋找最好的特徵,然後使用特徵劃分資料整合多個,使得子資料集的熵最小。
上面我們引入了“熵”這個詞,這個詞用於代表資料的混亂程度(在化學等領域也有這個詞,意思差不多),我們上面所說的使子資料集們的熵最小,意思也就是使用一個特徵劃分資料集,使每個子資料集擁有的分類儘可能少。
在這裡,我們將熵定義為資訊的期望值(ps:這是夏農定義的),在計算熵之前,我們首先得知道資訊的定義。如果待分類的事物可能劃分在多個分類中,這符號
其中
為了計算熵,我們需要計算所有類別所有可能值包含的資訊期望值,有下面的公式:
同時我們引入一個詞“資訊增益”,資訊增益指的是在劃分資料集之前之後資訊發生的變化。簡單來說就是使用某個屬性劃分資料集之後的資訊熵與劃分之前的資訊熵之差。資訊增益越大,表明資料的混亂程度減少的越多,說明更適合使用該屬性來劃分當前的資料集。
所以構建決策樹的主要思想就是,我們通過每次遍歷所有屬性,嘗試劃分資料集,尋找資訊增益最大的劃分,不斷遞迴劃分資料,直到所有屬性被用完(下面我的程式碼使用的ID3演算法,每次劃分消耗一個屬性,當然也有不消耗屬性的演算法)或者各子資料集內部資料型別相同。
根據西瓜資料集生成的決策樹樣式:
python程式碼實現(ID3演算法):
# -*- coding:utf-8 -*-
from math import log
def calc_shannon_ent(dataSet):
"""
計算夏農熵
:param dataSet: 待計算的資料集
:return: 夏農熵
"""
data_nums = len(dataSet)
label_count = {}
for i in dataSet:
label = i[-1]
if label not in label_count.keys():
label_count[label] = 0
label_count[label] += 1
shannon_ent = 0.0
for i in label_count.keys():
prob = float(label_count[i])/data_nums
shannon_ent -= prob*log(prob, 2)
return shannon_ent
def split_data_set(dataSet, index, value):
"""
消耗指定屬性,劃分資料集,返回指定值的資料
:param dataSet: 待劃分的資料集
:param index: 指定屬性所在的索引
:param value: 返回的資料集對應屬性的值
:return: new_data_set: 新資料集
"""
new_data_set = []
for i in dataSet:
if i[index] == value:
new_data = i[:index] # 消耗指定的資料,將資料集劃分
new_data.extend(i[index+1:])
new_data_set.append(new_data)
return new_data_set
def choose_best_split_method(dataSet):
"""
選擇最好的用於分類的屬性,返回屬性索引
:param dataSet: 資料集
:return: 用於分類的索引
"""
dataSet_num = len(dataSet)
attribute_num = len(dataSet[0]) - 1 # 資料集的屬性個數
base_entropy = calc_shannon_ent(dataSet) # 未分類的夏農熵
best_information_gain = 0.0 # 最好的資訊增益的值
best_index = -1 # 最好資訊增益對應的屬性索引
for i in range(attribute_num):
attributes = [one_data[i] for one_data in dataSet]
attributes = set(attributes)
entropy = 0.0 # 夏農熵
for value in attributes:
sub_data_set = split_data_set(dataSet, i, value)
prob = len(sub_data_set)/float(dataSet_num)
entropy += prob*calc_shannon_ent(sub_data_set) # 對劃分的兩個資料集夏農熵求均值
information_gain = base_entropy - entropy # 求此次劃分的資訊增益
if information_gain > best_information_gain:
best_information_gain = information_gain
best_index = i
return best_index
def majority_cnt(class_list):
"""
返回出現最多次數的分類名稱
:param class_list: 資料集中所有目標值分類的列表
:return: 出現次數最多的目標分類名稱
"""
class_count = {}
for i in class_list:
if i not in class_count.keys():
class_count[i] = 0
class_count[i] += 1
sorted_class_count = sorted(class_count.iteritems(), key=lambda x: x[1])
return sorted_class_count[0][0]
def create_tree(dataSet, labels):
"""
建立決策樹
:param dataSet: 資料集
:param labels: 所有的屬性名稱集
:return: 決策樹字典
"""
class_list = [item[-1] for item in dataSet]
if class_list.count(class_list[0]) == len(class_list):
return class_list[0]
if len(dataSet[0]) == 1: # 每次決策樹劃分分支都會消耗一個屬性,最後只剩一個目標屬性,說明無法劃分分支了
return majority_cnt(class_list)
best_index = choose_best_split_method(dataSet)
best_label = labels[best_index]
mytree = {best_label: {}}
del labels[best_index]
all_values = [item[best_index] for item in dataSet]
all_values = set(all_values)
for i in all_values:
sub_labels = labels[:]
mytree[best_label][i] = create_tree(split_data_set(dataSet, best_index, i), sub_labels)
return mytree
def classify(decision_tree, labels, one_data):
"""
使用決策樹進行分類
:param decision_tree: 決策樹
:param labels: 屬性名稱列表
:param one_data: 需要分類的資料
:return: 分類結果
"""
first = decision_tree.keys()[0]
second_dict = decision_tree[first]
first_index = labels.index(first)
for i in second_dict:
if one_data[first_index] == i:
if type(second_dict[i]).__name__ == 'dict':
class_label = classify(second_dict[i], labels, one_data)
else:
class_label = second_dict[i]
return class_label
if __name__ == '__main__':
# 使用資料集中的前部分資料構造決策樹,使用最後一條資料檢測決策樹的正確性
dataSet = []
with open('watermelon.txt', 'r') as fp:
data = fp.readline()
while data:
dataSet.append(data.split(' ')[1:])
data = fp.readline()
test_set = dataSet[-1]
dataSet = dataSet[:-1]
tree = create_tree(dataSet, [u'色澤', u'根蒂', u'敲聲', u'紋理', u'臍部', u'觸感'])
print classify(tree, [u'色澤', u'根蒂', u'敲聲', u'紋理', u'臍部', u'觸感'], test_set)
本文章僅為個人理解,如有錯誤歡迎批評指正,轉載請註明出處