決策樹實戰專案-鳶尾花分類
分享一下我老師大神的人工智慧教程!零基礎,通俗易懂!http://blog.csdn.net/jiangjunshow
也歡迎大家轉載本篇文章。分享知識,造福人民,實現我們中華民族偉大復興!
決策樹實戰專案-鳶尾花分類
一、實驗介紹
1.1 實驗內容
決策樹是機器學習中一種簡單而又經典的演算法。本次實驗將帶領瞭解決策樹的基本原理,並學習使用 scikit-learn 來構建一個決策樹分類模型,最後使用此模型預測鳶尾花的種類。
1.2 實驗知識點
- 決策樹的基本原理。
- 決策樹在生成和修剪中使用的 ID3, C4.5 及 CART 演算法。
- 使用 scikit-learn 中提供的決策樹分類器進行例項驗證。
1.3 實驗環境
- python2.7
- Xfce 終端
- ipython
1.4 適合人群
本課程難度為一般,屬於初級級別課程,適合具有 Python 基礎,並對機器學習中決策樹演算法感興趣的使用者。
1.5 程式碼獲取
你可以通過下面命令將程式碼下載到實驗樓環境中,作為參照對比進行學習。
$ wget http://labfile.oss.aliyuncs.com/courses/863/decisionTree.py
二、決策樹基本原理
2.1 決策樹簡介
決策樹是一種特殊的樹形結構,一般由節點和有向邊組成。其中,節點表示特徵、屬性或者一個類。而有向邊包含有判斷條件。如圖所示,決策樹從根節點開始延伸,經過不同的判斷條件後,到達不同的子節點。而上層子節點又可以作為父節點被進一步劃分為下層子節點。一般情況下,我們從根節點輸入資料,經過多次判斷後,這些資料就會被分為不同的類別。這就構成了一顆簡單的分類決策樹。
2.2 決策樹學習
我們將決策數的思想引入到機器學習中,就產生了一種簡單而又經典的預測方法 —— 決策樹學習(Decision Tree Learning),亦簡稱為決策樹。決策樹可以用來解決分類或迴歸問題,分別稱之為分類樹或迴歸樹。其中,分類樹的輸出是一個標量,而回歸樹的一般輸出為一個實數。
通常情況下,決策樹利用損失函式最小的原則建立模型,然後再利用該模型進行預測。決策樹學習通常包含三個階段:特徵選擇、樹的生成,樹的修剪。
2.3 特徵選擇
特徵選擇是建立決策樹之前十分重要的一步。如果是隨機地選擇特徵,那麼所建立決策樹的學習效率將會大打折扣。舉例來講,銀行採用決策樹來解決信用卡審批問題,判斷是否向某人發放信用卡可以根據其年齡、工作單位、是否有不動產、歷史信貸情況等特徵決定。而選擇不同的特徵,後續生成的決策樹就會不一致,這種不一致最終會影響到決策樹的分類效率。
通常我們在選擇特徵時,會考慮到兩種不同的指標,分別為:資訊增益和資訊增益比。要想弄清楚這兩個概念,我們就不得不提到資訊理論中的另一個十分常見的名詞 —— 熵。
熵(Entropy)是表示隨機變數不確定性的度量。簡單來講,熵越大,隨機變數的不確定性就越大。而特徵 A 對於某一訓練集 D 的資訊增益 g(D, A) 定義為集合 D 的熵 H(D) 與特徵 A 在給定條件下 D 的熵 H(D/A) 之差。
上面這段定義讀起來很拗口,也不是特別容易理解。那麼,下面我使用更通俗的語言概述一下。簡單來講,每一個特徵針對訓練資料集的前後資訊變化的影響是不一樣的,資訊增益越大,即代表這種影響越大。而影響越大,就表明該特徵更加重要。
2.4 生成演算法
當我們瞭解資訊增益的概念之後,我們就可以學習決策樹的生成演算法了。其中,最經典的就數 John Ross Quinlan 提出的 ID3 演算法,這個演算法的核心理論即源於上面提到的資訊增益。
ID3 演算法通過遞迴的方式建立決策樹。建立時,從根節點開始,對節點計算每個獨立特徵的資訊增益,選擇資訊增益最大的特徵作為節點特徵。接下來,對該特徵施加判斷條件,建立子節點。然後針對子節點再此使用資訊增益進行判斷,直到所有特徵的資訊增益很小或者沒有特徵時結束,這樣就逐步建立一顆完整的決策樹。
除了從資訊增益演化而來的 ID3 演算法,還有一種常見的演算法叫 C4.5。C4.5 演算法同樣由 John Ross Quinlan 發明,但它使用了資訊增益比來選擇特徵,這被看成是 ID3 演算法的一種改進。
ID3 和 C4.5 演算法簡單高效,但是他倆均存在一個缺點,那就是用“完美去造就了另一個不完美”。這兩個演算法從資訊增益和資訊增益比開始,對整個訓練集進行的分類,擬合出來的模型針對該訓練集的確是非常完美的。但是,這種完美就使得整體模型的複雜度較高,而對其他資料集的預測能力就降低了,也就是我們常說的過擬合而使得模型的泛化能力變弱。
當然,過擬合的問題也是可以解決的,那就是對決策樹進行修剪。
2.5 決策樹修剪
決策樹的修剪,其實就是通過優化損失函式來去掉不必要的一些分類特徵,降低模型的整體複雜度。修剪的方式,就是從樹的葉節點出發,向上回縮,逐步判斷。如果去掉某一特徵後,整棵決策樹所對應的損失函式更小,那就就將該特徵及帶有的分支剪掉。
由於 ID3 和 C4.5 只能生成決策樹,而修剪需要單獨進行,這也就使得過程更加複雜了。1984年,Breiman 提出了 CART 演算法,使這個過程變得可以一步到位。CART 演算法本身就包含了決策樹的生成和修剪,並且可以同時被運用到分類樹和迴歸樹。這就是和 ID3 及 C4.5 之間的最大區別。
CART 演算法在生成樹的過程中,分類樹採用了基尼指數(Gini Index)最小化原則,而回歸樹選擇了平方損失函式最小化原則。基尼指數其實和前面提到的熵的概念是很相似的。簡單概述區別的話,就是數值相近但不同,而基尼指數在運算過程中的速度會更快一些。
CART 演算法也包含了樹的修剪。CART 演算法從完全生長的決策樹底端剪去一些子樹,使得模型更加簡單。而修剪這些子樹時,是每次去除一顆,逐步修剪直到根節點,從而形成一個子樹序列。最後,對該子樹序列進行交叉驗證,再選出最優的子樹作為最終決策樹。
三、鳶尾花分類實驗
如果你感覺理論看起來比較費勁,不用擔心。接下來就帶領你用非常少的程式碼量來構建一個決策樹分類模型,實現對鳶尾花分類。
3.1 資料集簡介
鳶尾花資料集是機器學習領域一個非常經典的分類資料集。接下來,我們就用這個訓練集為基礎,一步一步地訓練一個機器學習模型。首先,我們來看一下該資料集的基本構成。資料集名稱的準確名稱為Iris Data Set,總共包含 150 行資料。每一行資料由 4 個特徵值及一個目標值組成。其中 4 個特徵值分別為:萼片長度、萼片寬度、花瓣長度、花瓣寬度。而目標值及為三種不同類別的鳶尾花,分別為:Iris Setosa,Iris Versicolour,Iris Virginica。
3.2 資料獲取及劃分
你可以通過著名的 UCI 機器學習資料集網站下載該資料集。本實驗中,為了更加便捷地實驗。我們直接實驗 scikit-learn 提供的方法匯入該資料集即可。開啟實驗環境右下角的選單 > 附件 > ipython,依次鍵入程式碼。
# -*- coding: utf-8 -*-from sklearn import datasets #匯入方法類iris = datasets.load_iris() #載入 iris 資料集iris_feature = iris.data #特徵資料iris_target = iris.target #分類資料
接下來,你可以直接通過 print iris_target
檢視一下花的分類資料。這裡,scikit-learn 已經將花的原名稱進行了轉換,其中 0, 1, 2 分別代表 Iris Setosa, Iris Versicolour 和 Iris Virginica。
你會發現,這些資料是按照鳶尾花類別的順序排列的。所以,如果我們將其直接劃分為訓練集和資料集的話,就會造成資料的分佈不均。詳細來講,直接劃分容易造成某種型別的花在訓練集中一次都未出現,訓練的模型就永遠不可能預測出這種花來。你可能會想到,我們將這些資料大亂後再劃分訓練集和資料集。當然,更方便地,scikit-learn 為我們提供了訓練集和資料集的方法。
from sklearn.cross_validation import train_test_splitfeature_train, feature_test, target_train, target_test = train_test_split(iris_feature, iris_target, test_size=0.33, random_state=42)
其中,feature_train
, feature_test
, target_train
,target_test
分別代表訓練集特徵、測試集特徵、訓練集目標值、驗證集特徵。test_size
引數代表劃分到測試集資料佔全部資料的百分比,你也可以用train_size
來指定訓練集所佔全部資料的百分比。一般情況下,我們會將整個訓練集劃分為 70% 訓練集和 30% 測試集。最後的 random_state
引數表示亂序程度。
資料集劃分之後,我們可以再次執行 print iris_target
看一下結果。
現在,你會發現花的種類已經變成了亂序狀態,並且只包含有整個訓練集的 70% 資料。
3.2 模型訓練及預測
劃分完訓練集和測試集之後,我們就可以開始預測了。首先是從 scikit-learn 中匯入決策樹分類器。然後實驗 fit 方法和 predict 方法對模型進行訓練和預測。
# -*- coding: utf-8 -*-from sklearn.tree import DecisionTreeClassifierdt_model = DecisionTreeClassifier() # 所以引數均置為預設狀態dt_model.fit(feature_train,target_train) # 使用訓練集訓練模型predict_results = dt_model.predict(feature_test) # 使用模型對測試集進行預測
DecisionTreeClassifier()
模型方法中也包含非常多的引數值。例如:
criterion = gini/entropy
可以用來選擇用基尼指數或者熵來做損失函式。splitter = best/random
用來確定每個節點的分裂策略。支援“最佳”或者“隨機”。max_depth = int
用來控制決策樹的最大深度,防止模型出現過擬合。min_samples_leaf = int
用來設定葉節點上的最少樣本數量,用於對樹進行修剪。
我們可以將預測結果和測試集的真實值分別輸出,對照比較。
當然,我們可以通過 scikit-learn 中提供的評估計算方法檢視預測結果的準確度。
from sklearn.metrics import accuracy_scoreprint accuracy_score(predict_results, target_test)
其實,在 scikit-learn 中的分類決策樹模型就帶有 score 方法,只是傳入的引數和 accuracy_score()
不太一致。
scores = dt_model.score(feature_test, target_test)
你可以看出兩種準確度方法輸入引數的區別。一般情況下,模型預測的準確度會和多方面因素相關。首先是資料集質量,本實驗中,我們使用的資料集非常規範,幾乎不包含噪聲,所以預測準確度非常高。其次,模型的引數也會對預測結果的準確度造成影響。
六、實驗總結
首先通過決策樹的原理,加深了對介紹機器學習中決策樹演算法的理解。並採用 scikit-learn 中提供的決策樹分類器構建預測模型,實現對鳶尾花進行分類。
七、課後習題
- 嘗試通過修改
DecisionTreeClassifier()
方法裡面的值,檢視模型引數對實驗結果帶來的影響。 - 嘗試載入 scikit-learn 中提供的另一個著名的 digits 資料集,同樣實驗決策樹分類器實現手寫字型識別實驗。
八、參考連結
- 《統計學習方法》,李航,清華大學出版社
- 鳶尾花資料集,維基百科
- 周志華《機器學習》習題解答:Ch4.4 - 程式設計實現CART演算法與剪枝操作