1. 程式人生 > >Scikit-Learn與迴歸樹

Scikit-Learn與迴歸樹

迴歸演算法原理

CART(Classification and Regression Tree)演算法是目前決策樹演算法中最為成熟的一類演算法,應用範圍也比較廣泛。它既可以用於分類。
西方預測理論一般都是基於迴歸的,CART是一種通過決策樹方法實現迴歸的演算法,它具有很多其他全域性迴歸演算法不具有的特性。
在建立迴歸模型時,樣本的取值分為觀察值和輸出值兩種,觀察值和輸出值都是連續的,不像分類函式那樣有分類標籤,只有根據資料集的資料特徵來建立一個預測的模型,反映曲線的變化趨勢。在這種情況下,原有分類樹的最優劃分規則就不再起作用了。在預測中,CART使用最小剩餘方差(Squared Residuals Minimization)來判定迴歸樹的最優劃分,這個準則期望劃分之後的子樹與樣本點的誤差方差最小

。這樣決策樹將資料集劃分成很多子模型資料,然後利用線性迴歸技術來建模。如果每次切分後的資料子集仍然難以擬合,就繼續切分。在這種切分方式下創建出的預測樹,每個葉子節點都是一個線性迴歸模型。這些線性迴歸模型反映了樣本集合(觀測集合)中蘊含的模式,也被稱為模型樹。因此,CART不僅支援正體預測,也支援區域性模式的預測,並有能力從整體中找到模式,或根據模式組合成一個整體。整體與模式之間的相互結合,對於預測分析有重要價值。因此CART決策樹演算法在預測中的應用非常廣泛。
下面介紹CART的演算法流程:
(1)決策樹主函式:決策樹的主函式是一個遞迴函式。該函式的主要功能是按照CART的規則生長出決策樹的每個分支節點,並根據終止條件結束演算法。
a.輸入需要分類的資料集和類別標籤。
b.使用最小剩餘方差判定迴歸樹的最優劃分,並建立特徵的劃分節點——最小剩餘方差子函式。
c.在劃分節點劃分資料集為兩部分——二分資料集子函式。
d.根據二分資料的結果構建出新的左右節點,作為樹生長出的兩個分支。
e.檢驗是否符合遞迴的終止條件。
f.將劃分的新節點包含的資料集和類別標籤作為輸入,遞迴執行上述步驟。
(2)使用最小剩餘方差子函式,計算資料集各列的最優劃分方差、劃分列、劃分值
(3)二分資料集:根據給定的分隔列和分隔值將資料集一分為二,分別返回。

最小剩餘方差法

在迴歸樹中,資料集均為連續性。連續資料的處理方法與離散資料不同,離散資料是按每個特徵的取值劃分,而連續特徵則要計算出最優劃分點。但在連續資料集上計算線性相關度非常簡單,演算法思想來源於最小二乘法。
最小剩餘方差法,首先求取劃分資料列的均值和總方差。總方差的計算方法有兩種
求取均值std,計算每個資料點與std的方差,然後將n個點求和
求取方差var,然後var_sum = var*n,n為資料集資料數目。
那麼,每次最佳分支特徵的選取過程如下。
(1)先令最佳方差為無限大 bestVar = inf。
(2)此次遍歷所有特徵列及每個特徵列的所有樣本點(這是一個二迴圈),在每個樣本點上二分資料集。
(3)計算二分資料集後的總方差currentVar,如果currentVar < bestVar,則bestVar = currentVar。
返回計算的最優分支特徵列、分支特徵值(連續特徵則為劃分點的值)以及左右分支子資料集到主程式。

模型樹

使用CART進行預測是把葉子節點設定為一系列的分段線性函式,這些分段線性函式是對源資料曲線的一種模擬,每個線性函式都被稱為一顆模型樹。模型樹具有很多優秀的性質,它包含了如下特徵。
一般而言,樣本總體的重複性不會很高,但區域性模式經常重複,也就是所說的歷史不會簡單的重複,但會重演。模型比總體對未來的預測而言更有用。
模型給出了資料的範圍,它可能是一個時間範圍,也可能是一個空間範圍;而且模型還給出了變化的趨勢,可以是曲線,也可以是直線,這依賴於使用的迴歸演算法。這些因素使模型具有很強的可解釋性。
傳統的迴歸方法,無論是線性迴歸還是非線性迴歸,都不如模型樹包含的資訊豐富,因此模型樹具有更高的預測準確度。

Scikit-Learn實現

#!/usr/bin/python
# created by lixin 20161118
import numpy as np
from numpy import *
from sklearn.tree import DecisionTreeRegressor
import matplotlib.pyplot as plt


def plotfigure(X,X_test,y,yp):
        plt.figure()
        plt.scatter(X,y,c="k",label="data")
        plt.plot(X_test,yp,c="r",label="max_depth=5",linewidth=2)
        plt.xlabel("data")
        plt.ylabel("target")
        plt.title("Decision Tree Regression")
        plt.legend(loc='upper right')
        plt.show()
        #plt.savefig('./res.png', format='png')


x = np.linspace(-5,5,200)
siny = np.sin(x)
X = mat(x).T
y = siny + np.random.rand(1,len(siny))*1.5
y = y.tolist()[0]
clf = DecisionTreeRegressor(max_depth=4)
clf.fit(X,y)

X_test = np.arange(-5.0,5.0,0.05)[:,np.newaxi
yp = clf.predict(X_test)

plotfigure(X,X_test,y,yp)

圖1