1. 程式人生 > >CART迴歸樹--《機器學習實戰》

CART迴歸樹--《機器學習實戰》

在學習整合方法的過程中,順著思路來到CART迴歸樹,它作為GBDT的基學習器,是以均方誤差作為損失函式,找到其取極小值時的點作為切分點,將資料集劃分為左右子樹,然後繼續上面的步驟。

下面是程式碼部分,由於《機器學習實戰》書中的程式碼存在部分錯誤,下面給予修正。

# _*_ coding: UTF-8 _*_
from numpy import *
import numpy as np
import pickle
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties  # 設定字型屬性
def 
loadDataSet(fileName): ''' 讀取一個一tab鍵為分隔符的檔案,然後將每行的內容儲存成一組浮點數 ''' dataMat = [] fr = open(fileName) for line in fr.readlines(): curLine = line.strip().split('\t') fltLine = list(map(float,curLine)) dataMat.append(fltLine) return dataMat def binSplitDataSet(dataSet, feature, value): '''
資料集切分函式 ''' mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:] mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:] return mat0,mat1 #--------------迴歸樹所需子函式---------------# def regLeaf(dataSet): '''負責生成葉節點''' #當chooseBestSplit()函式確定不再對資料進行切分時,將呼叫本函式來得到葉節點的模型。 #在迴歸樹中,該模型其實就是目標變數的均值。
return np.mean(dataSet[:,-1]) def regErr(dataSet): ''' 誤差估計函式,該函式在給定的資料上計算目標變數的平方誤差,這裡直接呼叫均方差函式 ''' return var(dataSet[:,-1]) * shape(dataSet)[0]#返回總方差 #--------------迴歸樹子函式 END --------------# def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)): ''' 用最佳方式切分資料集和生成相應的葉節點 ''' #ops為使用者指定引數,用於控制函式的停止時機 tolS = ops[0]; tolN = ops[1] #如果所有值相等則退出 if len(set(dataSet[:,-1].T.tolist()[0])) == 1: return None, leafType(dataSet) m,n = shape(dataSet) S = errType(dataSet) bestS = inf; bestIndex = 0; bestValue = 0 #在所有可能的特徵及其可能取值上遍歷,找到最佳的切分方式 #最佳切分也就是使得切分後能達到最低誤差的切分 for featIndex in range(n-1): for splitVal in set(dataSet[:, featIndex].T.A.tolist()[0]): mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal) if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue newS = errType(mat0) + errType(mat1) if newS < bestS: bestIndex = featIndex bestValue = splitVal bestS = newS #如果誤差減小不大則退出 if (S - bestS) < tolS: return None, leafType(dataSet) mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue) #如果切分出的資料集很小則退出 if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): return None, leafType(dataSet) #提前終止條件都不滿足,返回切分特徵和特徵值 return bestIndex,bestValue def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)): ''' 樹構建函式 leafType:建立葉節點的函式 errType:誤差計算函式 ops:包含樹構建所需其他引數的元組 ''' #選擇最優的劃分特徵 #如果滿足停止條件,將返回None和某類模型的值 #若構建的是迴歸樹,該模型是一個常數;如果是模型樹,其模型是一個線性方程 feat, val = chooseBestSplit(dataSet, leafType, errType, ops) if feat == None: return val # retTree = {} retTree['spInd'] = feat retTree['spVal'] = val #將資料集分為兩份,之後遞迴呼叫繼續劃分 lSet, rSet = binSplitDataSet(dataSet, feat, val) retTree['left'] = createTree(lSet, leafType, errType, ops) retTree['right'] = createTree(rSet, leafType, errType, ops) return retTree def storeTree(inputTree, filename): with open(filename, 'wb') as fw: pickle.dump(inputTree, fw) if __name__ == '__main__': myDat = loadDataSet('ex00.txt') x=[x[0] for x in myDat] y=[y[1] for y in myDat] font=FontProperties(fname=r"c:\windows\fonts\simsun.ttc",size=14) plt.figure(figsize=(8,4)) plt.scatter(x,y) plt.xlabel("x") plt.ylabel("y") plt.title(u"基於CART演算法構建迴歸樹的簡單資料集",fontproperties=font) # 在圖上輸出中文標題 plt.show() myMat = mat(myDat) retTree = createTree(myMat) print(retTree) storeTree(retTree, 'retTree.txt')
執行效果如下:

(1)輸出資料集

(2)輸出迴歸樹


修改資料集為ex0.txt,輸出結果如下:

(1)散點圖


(2)迴歸樹

{'spInd': 1, 'spVal': 0.39435, 'left': {'spInd': 1, 'spVal': 0.582002, 'left': {'spInd': 1, 'spVal': 0.797583
, 'left': 3.9871632, 'right': 2.9836209534883724}, 'right': 1.980035071428571}, 'right': {'spInd': 1, 'spVal'
: 0.197834, 'left': 1.0289583666666666, 'right': -0.023838155555555553}}