1. 程式人生 > >cart樹回歸及其剪枝的python實現

cart樹回歸及其剪枝的python實現

mat 接下來 更多 split 討論 也有 其中 程序 target

轉自穆晨

閱讀目錄

  • 前言
  • 回歸樹
  • 回歸樹的優化工作 - 剪枝
  • 模型樹
  • 回歸樹 / 模型樹的使用
  • 小結
回到頂部

前言

前文討論的回歸算法都是全局且針對線性問題的回歸,即使是其中的局部加權線性回歸法,也有其弊端(具體請參考前文)

采用全局模型會導致模型非常的臃腫,因為需要計算所有的樣本點,而且現實生活中很多樣本都有大量的特征信息。

另一方面,實際生活中更多的問題都是非線性問題。

針對這些問題,有了樹回歸系列算法。

回到頂部

回歸樹

在先前決策樹的學習中,構建樹是采用的 ID3 算法。在回歸領域,該算法就有個問題,就是派生子樹是按照所有可能值來進行派生。

因此 ID3 算法無法處理連續性數據。

故可使用二元切分法,以某個特定值為界進行切分。在這種切分法下,子樹個數小於等於2。

除此之外,再修改擇優原則香農熵 (因為數據變為連續型的了),便可將樹構建成一棵可用於回歸的樹,這樣一棵樹便叫做回歸樹。

構建回歸樹的偽代碼:

1 找到最佳的待切分特征:
2     如果該節點不能再分,將此節點存為葉節點。
3     執行二元切分
4     左右子樹分別遞歸調用此函數

二元切分的偽代碼:

1 對每個特征:
2     對每個特征值:
3         將數據集切成兩份
4         計算切分誤差
5         如果當前誤差小於最小誤差,則更新最佳切分以及最小誤差。

特別說明,終止劃分 (並直接建立葉節點)有三種情況:
1. 特征值劃分完畢
2. 劃分子集太小
3. 劃分後誤差改進不大
這幾個操作被稱做 "預剪枝"。
  下面給出一個完整的回歸樹的小程序:

技術分享
  1 #!/usr/bin/env python
  2 # -*- coding:UTF-8 -*-
  3 
  4 ‘‘‘
  5 Created on 20**-**-**
  6 
  7 @author: fangmeng
  8 ‘‘‘
  9 
 10 from numpy import *
 11 
 12 def loadDataSet(fileName):
 13     ‘載入測試數據‘
 14     
 15     dataMat = []
 16     fr = open(fileName)
 17     for line in fr.readlines():
 18         curLine = line.strip().split(‘\t‘)
 19         # 所有元素轉換為浮點類型(函數編程)
 20         fltLine = map(float,curLine)
 21         dataMat.append(fltLine)
 22     return dataMat
 23 
 24 #============================
 25 # 輸入:
 26 #        dataSet: 待切分數據集
 27 #        feature: 切分特征序號
 28 #        value:    切分值
 29 # 輸出:
 30 #        mat0,mat1: 切分結果
 31 #============================
 32 def binSplitDataSet(dataSet, feature, value):
 33     ‘切分數據集‘
 34     
 35     mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:][0]
 36     mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:][0]
 37     return mat0,mat1
 38 
 39 #========================================
 40 # 輸入:
 41 #        dataSet: 數據集
 42 # 輸出:
 43 #        mean(dataSet[:,-1]): 均值(也就是葉節點的內容)
 44 #========================================
 45 def regLeaf(dataSet):
 46     ‘生成葉節點‘
 47     
 48     return mean(dataSet[:,-1])
 49 
 50 #========================================
 51 # 輸入:
 52 #        dataSet: 數據集
 53 # 輸出:
 54 #        var(dataSet[:,-1]) * shape(dataSet)[0]: 平方誤差
 55 #========================================
 56 def regErr(dataSet):
 57     ‘計算平方誤差‘
 58     
 59     return var(dataSet[:,-1]) * shape(dataSet)[0]
 60 
 61 #========================================
 62 # 輸入:
 63 #        dataSet: 數據集
 64 #        leafType: 葉子節點生成器
 65 #        errType: 誤差統計器
 66 #        ops: 相關參數
 67 # 輸出:
 68 #        bestIndex: 最佳劃分特征 
 69 #        bestValue: 最佳劃分特征值
 70 #========================================
 71 def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
 72     ‘選擇最優劃分‘
 73     
 74     # 獲得相關參數中的最大樣本數和最小誤差效果提升值
 75     tolS = ops[0]; 
 76     tolN = ops[1]
 77     
 78     # 如果所有樣本點的值一致,那麽直接建立葉子節點。
 79     if len(set(dataSet[:,-1].T.tolist()[0])) == 1: 
 80         return None, leafType(dataSet)
 81     
 82     m,n = shape(dataSet)
 83     # 當前誤差
 84     S = errType(dataSet)
 85     # 最小誤差
 86     bestS = inf; 
 87     # 最小誤差對應的劃分方式
 88     bestIndex = 0; 
 89     bestValue = 0
 90     
 91     # 對於所有特征
 92     for featIndex in range(n-1):
 93         # 對於某個特征的所有特征值
 94         for splitVal in set(dataSet[:,featIndex]):
 95             # 劃分
 96             mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
 97             # 如果劃分後某個子集中的個數不達標
 98             if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue
 99             # 當前劃分方式的誤差
100             newS = errType(mat0) + errType(mat1)
101             # 如果這種劃分方式的誤差小於最小誤差
102             if newS < bestS: 
103                 bestIndex = featIndex
104                 bestValue = splitVal
105                 bestS = newS
106     
107     # 如果當前劃分方式還不如不劃分時候的誤差效果
108     if (S - bestS) < tolS: 
109         return None, leafType(dataSet)
110     # 按照最優劃分方式進行劃分
111     mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
112     # 如果劃分後某個子集中的個數不達標
113     if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
114         return None, leafType(dataSet)
115     
116     return bestIndex,bestValue
117 
118 #========================================
119 # 輸入:
120 #        dataSet: 數據集
121 #        leafType: 葉子節點生成器
122 #        errType: 誤差統計器
123 #        ops: 相關參數
124 # 輸出:
125 #        retTree: 回歸樹
126 #========================================
127 def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
128     ‘構建回歸樹‘
129     
130     # 選擇最佳劃分方式
131     feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
132     # feat為None的時候無需劃分返回葉子節點
133     if feat == None: return val #if the splitting hit a stop condition return val
134     
135     # 遞歸調用構建函數並更新樹
136     retTree = {}
137     retTree[‘spInd‘] = feat
138     retTree[‘spVal‘] = val
139     lSet, rSet = binSplitDataSet(dataSet, feat, val)
140     retTree[‘left‘] = createTree(lSet, leafType, errType, ops)
141     retTree[‘right‘] = createTree(rSet, leafType, errType, ops)
142     
143     return retTree  
144 
145 def test():
146     ‘展示結果‘
147     
148     # 載入數據
149     myDat = loadDataSet(‘/home/fangmeng/ex0.txt‘)
150     # 構建回歸樹
151     myDat = mat(myDat)
152     
153     print createTree(myDat)
154     
155     
156 if __name__ == ‘__main__‘:
157     test()
技術分享

測試結果:

技術分享

回到頂部

回歸樹的優化工作 - 剪枝

在上面的代碼中,終止遞歸的條件中已經加入了重重的 "剪枝" 工作。

這些在建樹的時候的剪枝操作通常被成為預剪枝。這是很有很有必要的,經過預剪枝的樹幾乎就是沒有預剪枝樹的大小的百分之一甚至更小,而性能相差無幾。

而在樹建立完畢之後,基於訓練集和測試集能做更多更高效的剪枝工作,這些工作叫做 "後剪枝"。

可見,剪枝是一項較大的工作量,是對樹非常關鍵的優化過程。

後剪枝過程的偽代碼如下:

1 基於已有的樹切分測試數據:
2     如果存在任一子集是一棵樹,則在該子集上遞歸該過程。
3     計算將當前兩個葉節點合並後的誤差
4     計算不合並的誤差
5     如果合並會降低誤差,則將葉節點合並。

具體實現函數如下:

技術分享
 1 #===================================
 2 # 輸入:
 3 #        obj: 判斷對象
 4 # 輸出:
 5 #        (type(obj).__name__==‘dict‘): 判斷結果
 6 #===================================
 7 def isTree(obj):
 8     ‘判斷對象是否為樹類型‘
 9     
10     return (type(obj).__name__==‘dict‘)
11 
12 #===================================
13 # 輸入:
14 #        tree: 處理對象
15 # 輸出:
16 #        (tree[‘left‘]+tree[‘right‘])/2.0: 坍塌後的替代值
17 #===================================
18 def getMean(tree):
19     ‘坍塌處理‘
20     
21     if isTree(tree[‘right‘]): tree[‘right‘] = getMean(tree[‘right‘])
22     if isTree(tree[‘left‘]): tree[‘left‘] = getMean(tree[‘left‘])
23     
24     return (tree[‘left‘]+tree[‘right‘])/2.0
25   
26 #===================================
27 # 輸入:
28 #        tree: 處理對象
29 #        testData: 測試數據集
30 # 輸出:
31 #        tree: 剪枝後的樹
32 #===================================  
33 def prune(tree, testData):
34     ‘後剪枝‘
35     
36     # 無測試數據則坍塌此樹
37     if shape(testData)[0] == 0: 
38         return getMean(tree)
39     
40     # 若左/右子集為樹類型
41     if (isTree(tree[‘right‘]) or isTree(tree[‘left‘])):
42         # 劃分測試集
43         lSet, rSet = binSplitDataSet(testData, tree[‘spInd‘], tree[‘spVal‘])
44     # 在新樹新測試集上遞歸進行剪枝
45     if isTree(tree[‘left‘]): tree[‘left‘] = prune(tree[‘left‘], lSet)
46     if isTree(tree[‘right‘]): tree[‘right‘] =  prune(tree[‘right‘], rSet)
47     
48     # 如果兩個子集都是葉子的話,則在進行誤差評估後決定是否進行合並。
49     if not isTree(tree[‘left‘]) and not isTree(tree[‘right‘]):
50         lSet, rSet = binSplitDataSet(testData, tree[‘spInd‘], tree[‘spVal‘])
51         errorNoMerge = sum(power(lSet[:,-1] - tree[‘left‘],2)) +sum(power(rSet[:,-1] - tree[‘right‘],2))
52         treeMean = (tree[‘left‘]+tree[‘right‘])/2.0
53         errorMerge = sum(power(testData[:,-1] - treeMean,2))
54         if errorMerge < errorNoMerge: 
55             return treeMean
56         else: return tree
57     else: return tree
技術分享 回到頂部

模型樹

這也是一種很棒的樹回歸算法。

該算法將所有的葉子節點不是表述成一個值,而是對葉子部分節點建立線性模型。比如可以是最小二乘法的基本線性回歸模型。

這樣在葉子節點裏存放的就是一組線性回歸系數了。非葉子節點部分構造就和回歸樹一樣。

這個是上面建立回歸樹算法的函數頭:

createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):

對於模型樹,只需要修改修改 leafType(葉節點構造器) 和 errType(誤差分析器) 的實現即可,分別對應如下modelLeaf 函數和 modelErr 函數:

技術分享
 1 #=========================
 2 # 輸入:
 3 #        dataSet: 測試集
 4 # 輸出:
 5 #        ws,X,Y: 回歸模型
 6 #=========================
 7 def linearSolve(dataSet):
 8     ‘輔助函數,用於構建線性回歸模型。‘
 9     
10     m,n = shape(dataSet)
11     X = mat(ones((m,n))); 
12     Y = mat(ones((m,1)))
13     X[:,1:n] = dataSet[:,0:n-1]; 
14     Y = dataSet[:,-1]
15     xTx = X.T*X
16     if linalg.det(xTx) == 0.0:
17         raise NameError(‘系數矩陣不可逆‘)
18     ws = xTx.I * (X.T * Y)
19     return ws,X,Y
20 
21 #=======================
22 # 輸入:
23 #       dataSet: 數據集
24 # 輸出:
25 #        ws: 回歸系數
26 #=======================
27 def modelLeaf(dataSet):
28     ‘葉節點構造器‘
29     
30     ws,X,Y = linearSolve(dataSet)
31     return ws
32 
33 #=======================================
34 # 輸入:
35 #       dataSet: 數據集
36 # 輸出:
37 #        sum(power(Y - yHat,2)): 平方誤差
38 #=======================================
39 def modelErr(dataSet):
40     ‘誤差分析器‘
41     
42     ws,X,Y = linearSolve(dataSet)
43     yHat = X * ws
44     return sum(power(Y - yHat,2))
技術分享 回到頂部

回歸樹 / 模型樹的使用

前面的工作主要介紹了兩種樹 - 回歸樹,模型樹的構建,下面進一步學習如何利用這些樹來進行預測。

當然,本質也就是遞歸遍歷樹。

下為遍歷代碼,通過修改參數設置要使用並傳遞進來的是回歸樹還是模型樹:

技術分享
 1 #==============================
 2 # 輸入:
 3 #       model: 葉子
 4 #       inDat: 測試數據
 5 # 輸出:
 6 #        float(model): 葉子值
 7 #==============================
 8 def regTreeEval(model, inDat):
 9     ‘回歸樹預測‘
10     
11     return float(model)
12 
13 #==============================
14 # 輸入:
15 #       model: 葉子
16 #       inDat: 測試數據
17 # 輸出:
18 #        float(X*model): 葉子值
19 #==============================
20 def modelTreeEval(model, inDat):
21     ‘模型樹預測‘
22     n = shape(inDat)[1]
23     X = mat(ones((1,n+1)))
24     X[:,1:n+1]=inDat
25     return float(X*model)
26 
27 #==============================
28 # 輸入:
29 #        tree: 待遍歷樹
30 #        inDat: 測試數據
31 #        modelEval: 葉子值獲取器
32 # 輸出:
33 #        分類結果
34 #==============================
35 def treeForeCast(tree, inData, modelEval=regTreeEval):
36     ‘使用回歸/模型樹進行預測 (modelEval參數指定)‘
37     
38     # 如果非樹類型,返回值。
39     if not isTree(tree): return modelEval(tree, inData)
40     
41     # 左遍歷
42     if inData[tree[‘spInd‘]] > tree[‘spVal‘]:
43         if isTree(tree[‘left‘]): return treeForeCast(tree[‘left‘], inData, modelEval)
44         else: return modelEval(tree[‘left‘], inData)
45         
46     # 右遍歷
47     else:
48         if isTree(tree[‘right‘]): return treeForeCast(tree[‘right‘], inData, modelEval)
49         else: return modelEval(tree[‘right‘], inData)
技術分享

使用方法非常簡單,將樹和要分類的樣本傳遞進去就可以了。如果是模型樹就將分類函數 treeForeCast 的第三個參數改為modelTreeEval即可。

這裏就不再演示實驗具體過程了。

回到頂部

小結

1. 選擇哪個回歸方法,得看哪個方法的相關系數高。(可使用 corrcoef 函數計算)

2. 樹的回歸和分類算法其實本質上都屬於貪心算法,不斷去尋找局部最優解。

3. 關於回歸的討論就先告一段落,接下來將進入到無監督學習部分。

cart樹回歸及其剪枝的python實現