GBDT多分類示例
樣本編號 | 花萼長度(cm) | 花萼寬度(cm) | 花瓣長度(cm) | 花瓣寬度 | 花的種類 |
1 | 5.1 | 3.5 | 1.4 | 0.2 | 山鳶尾 |
2 | 4.9 | 3.0 | 1.4 | 0.2 | 山鳶尾 |
3 | 7.0 | 3.2 | 4.7 | 1.4 | 雜色鳶尾 |
4 | 6.4 | 3.2 | 4.5 | 1.5 | 雜色鳶尾 |
5 | 6.3 | 3.3 | 6.0 | 2.5 | 維吉尼亞鳶尾 |
6 | 5.8 | 2.7 | 5.1 | 1.9 | 維吉尼亞鳶尾 |
Iris數據集
這是一個有6個樣本的三分類問題。我們需要根據這個花的花萼長度,花萼寬度,花瓣長度,花瓣寬度來判斷這個花屬於山鳶尾,雜色鳶尾,還是維吉尼亞鳶尾。具體應用到gbdt多分類算法上面。我們用一個三維向量來標誌樣本的label。[1,0,0] 表示樣本屬於山鳶尾,[0,1,0] 表示樣本屬於雜色鳶尾,[0,0,1] 表示屬於維吉尼亞鳶尾。
gbdt 的多分類是針對每個類都獨立訓練一個 CART Tree。所以這裏,我們將針對山鳶尾類別訓練一個 CART Tree 1。雜色鳶尾訓練一個 CART Tree 2 。維吉尼亞鳶尾訓練一個CART Tree 3,這三個樹相互獨立。
我們以樣本 1 為例。針對 CART Tree1 的訓練樣本是[5.1,3.5,1.4,0.2],label 是 1,最終輸入到模型當中的為[5.1,3.5,1.4,0.2,1]。針對 CART Tree2 的訓練樣本也是[5.1,3.5,1.4,0.2],但是label 為 0,最終輸入模型的為[5.1,3.5,1.4,0.2,0]. 針對 CART Tree 3的訓練樣本也是
下面我們來看 CART Tree1 是如何生成的,其他樹 CART Tree2 , CART Tree 3的生成方式是一樣的。CART Tree的生成過程是從這四個特征中找一個特征做為CART Tree1 的節點。比如花萼長度做為節點。6個樣本當中花萼長度 大於5.1 cm的就是 A類,小於等於 5.1 cm 的是B類。生成的過程其實非常簡單,問題 1.是哪個特征最合適? 2.是這個特征的什麽特征值作為切分點? 即使我們已經確定了花萼長度做為節點。花萼長度本身也有很多值。在這裏我們的方式是遍歷所有的可能性,找到一個最好的特征和它對應的最優特征值可以讓當前式子的值最小。
我們以第一個特征的第一個特征值為例。R1 為所有樣本中花萼長度小於 5.1 cm 的樣本集合,R2 為所有樣本當中花萼長度大於等於 5.1cm 的樣本集合。所以 R1={2},R2={1,3,4,5,6}.
一棵樹,是這樣進行訓練的,來了一個樣本,它有一個標記,將對應標記的放入到對應標記的樹裏面去,放入到其他樹裏面去的時候,標記為0,放入到對應樹裏面去的時候,標記為1;這樣,每次訓練m顆樹,m為總的類別數,訓練k輪下來就有m*k顆樹
GBDT多分類示例