spark mllib原始碼分析之DecisionTree與GBDT
我們在前面的文章講過,在spark的實現中,樹模型的依賴鏈是GBDT-> Decision Tree-> Random Forest,前面介紹了最基礎的Random Forest的實現,在此基礎上我們介紹Decision Tree和GBDT的實現。
1. Decision Tree
1.1. DT的使用
官方給出的demo
// Train a DecisionTree model.
// Empty categoricalFeaturesInfo indicates all features are continuous.
val numClasses = 2
val categoricalFeaturesInfo = Map[Int, Int]()
val impurity = "gini"
val maxDepth = 5
val maxBins = 32
val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
impurity, maxDepth, maxBins)
其入參除了不需要指定樹個數,其他引數與隨機森林類似,不再贅述
1.2 實現
主要的邏輯在DecisionTree.scala的run函式中
/**
* Method to train a decision tree model over an RDD
* @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
* @return DecisionTreeModel that can be used for prediction
*/
@Since("1.2.0")
def run(input: RDD[LabeledPoint]): DecisionTreeModel = {
// Note: random seed will not be used since numTrees = 1.
val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0)
val rfModel = rf.run(input)
rfModel.trees(0)
}
其實就是Random Forest 1棵樹的情形,同時特徵不再抽樣。
2. Gradient Boosting Decision Tree
2.1. 演算法簡介
簡稱GBDT,中文譯作梯度提升決策樹,估計沒幾個人聽過。這裡貼幾張之前介紹GBDT的PPT,簡單回顧起演算法原理,其中內容來自wikipedia和”From RankNet to LambdaRank to LambdaMAR An Overview”這篇文章。
2.1.1. 演算法原理
在這個演算法裡面,並沒有限定使用決策樹,如果使用決策樹,對應裡面的h應該是樹結構,我們以決策樹說明
1. 使用原始樣本直接訓練一棵樹
迴圈訓練
2. 計算偽殘差,實際是梯度
3. 將2中的偽殘差作為樣本的label去訓練決策樹
4. 這裡是用最優化方法計算葉子節點的輸出,而spark中直接使用的均值
5. 計算當輪模型的輸出,方法是上一輪的輸出加上本輪的預測值
6. 迴圈結束後,輸出模型
2.1.2. 以二分類為例
2.2. GBDT使用
官方demo
// Train a GradientBoostedTrees model.
// The defaultParams for Classification use LogLoss by default.
val boostingStrategy = BoostingStrategy.defaultParams("Classification")
boostingStrategy.numIterations = 3 // Note: Use more iterations in practice.
boostingStrategy.treeStrategy.numClasses = 2
boostingStrategy.treeStrategy.maxDepth = 5
// Empty categoricalFeaturesInfo indicates all features are continuous.
boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]()
val model = GradientBoostedTrees.train(trainingData, boostingStrategy)
首先初始化訓練引數boostingStrategy,然後設定其迭代次數,分類樹,樹的最大深度,離散特徵及其特徵值數,我們看下預設的引數都有哪些
/**
* Returns default configuration for the boosting algorithm
* @param algo Learning goal. Supported:
* [[org.apache.spark.mllib.tree.configuration.Algo.Classification]],
* [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
* @return Configuration for boosting algorithm
*/
@Since("1.3.0")
def defaultParams(algo: Algo): BoostingStrategy = {
val treeStrategy = Strategy.defaultStrategy(algo)
treeStrategy.maxDepth = 3
algo match {
case Algo.Classification =>
treeStrategy.numClasses = 2
new BoostingStrategy(treeStrategy, LogLoss)
case Algo.Regression =>
new BoostingStrategy(treeStrategy, SquaredError)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by boosting.")
}
}
預設樹的最大深度為3,如果是分類,為二分類,使用LogLoss;如果是迴歸,使用SquareError,均方誤差。然後使用Strategy的預設引數
/**
* Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]]
* @param algo Algo.Classification or Algo.Regression
*/
@Since("1.3.0")
def defaultStrategy(algo: Algo): Strategy = algo match {
case Algo.Classification =>
new Strategy(algo = Classification, impurity = Gini, maxDepth = 10,
numClasses = 2)
case Algo.Regression =>
new Strategy(algo = Regression, impurity = Variance, maxDepth = 10,
numClasses = 0)
}
Strategy的預設引數也比較簡單,其意義參見之前的文章。
2.3. GBDT實現
其實現開始於GradientBoostedTrees.scala的run函式
/**
* Method to train a gradient boosting model
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* @return a gradient boosted trees model that can be used for prediction
*/
@Since("1.2.0")
def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = {
val algo = boostingStrategy.treeStrategy.algo
algo match {
case Regression =>
GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false)
case Classification =>
// Map labels to -1, +1 so binary classification can be treated as regression.
val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
}
}
從其註釋可以看到,spark GBDT只實現了二分類,並且二分類的class必須是0/1,其把0/1轉化成-1/+1的label,然後按回歸處理。
2.3.2. 資料結構
2.3.2.1. LogLoss
在第二頁PPT中我們給出了loss,spark使用的loss是σ=1,log前增加了係數2的情況
對應梯度變成
其中m-1指的是在第m次迭代中,使用的是m-1次的預測值。注意到我們的PPT的第四頁的γ,其實是葉子節點的預測值,是通過最優化得到的,而spark這裡使用的是Random Forest的程式碼,其impurity選擇的是variance,因此預測值是均值。
@Since("1.2.0")
override def gradient(prediction: Double, label: Double): Double = {
- 4.0 * label / (1.0 + math.exp(2.0 * label * prediction))
}
override private[mllib] def computeError(prediction: Double, label: Double): Double = {
//loss
val margin = 2.0 * label * prediction
// The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
2.0 * MLUtils.log1pExp(-margin)
}
SquaredError比較簡單,這裡不再囉嗦了。
2.3.1. init
將傳入的引數轉成訓練時的引數,cache predError和validatePredError,並且按treeStrategy.getCheckpointInterval(default 10)建立checkpoint。這裡程式碼比較簡單,不再贅述。
2.3.2. build the first tree
參照演算法原理的第一步,訓練了第一棵樹,並且將weight設為1,,然後計算錯誤率。呼叫了computeInitialPredictionAndError函式
/**
* :: DeveloperApi ::
* Compute the initial predictions and errors for a dataset for the first
* iteration of gradient boosting.
* @param data: training data.
* @param initTreeWeight: learning rate assigned to the first tree.
* @param initTree: first DecisionTreeModel.
* @param loss: evaluation metric.
* @return a RDD with each element being a zip of the prediction and error
* corresponding to every sample.
*/
@Since("1.4.0")
@DeveloperApi
def computeInitialPredictionAndError(
data: RDD[LabeledPoint],
initTreeWeight: Double,
initTree: DecisionTreeModel,
loss: Loss): RDD[(Double, Double)] = {
data.map { lp =>
val pred = initTreeWeight * initTree.predict(lp.features)
val error = loss.computeError(pred, lp.label)
(pred, error)
}
}
其中預測值直接使用DT的predict來預測,error使用loss的computeError函式,我們上面有介紹。
2.3.3. 迴圈訓練
2.3.3.1. 樣本處理
對應演算法的第2步,計算梯度,並且作為label更新樣本
val data = predError.zip(input).map { case ((pred, _), point) =>
LabeledPoint(-loss.gradient(pred, point.label), point.features)
}
2.3.3.2. 訓練樹
對應演算法的第3和第4步,用第2步的樣本作為輸入,訓練決策樹
val model = new DecisionTree(treeStrategy).run(data)
timer.stop(s"building tree $m")
// Update partial model
baseLearners(m) = model
// Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
// Technically, the weight should be optimized for the particular loss.
// However, the behavior should be reasonable, though not optimal.
baseLearnerWeights(m) = learningRate
2.3.3.3. 計算模型輸出
實際呼叫updatePredictionError函式,入參是原始的樣本,上一輪的錯誤率(實際包含上一輪的模型輸出),本來的決策樹,學習率和loss計算物件。
/**
* :: DeveloperApi ::
* Update a zipped predictionError RDD
* (as obtained with computeInitialPredictionAndError)
* @param data: training data.
* @param predictionAndError: predictionError RDD
* @param treeWeight: Learning rate.
* @param tree: Tree using which the prediction and error should be updated.
* @param loss: evaluation metric.
* @return a RDD with each element being a zip of the prediction and error
* corresponding to each sample.
*/
@Since("1.4.0")
@DeveloperApi
def updatePredictionError(
data: RDD[LabeledPoint],
predictionAndError: RDD[(Double, Double)],
treeWeight: Double,
tree: DecisionTreeModel,
loss: Loss): RDD[(Double, Double)] = {
val newPredError = data.zip(predictionAndError).mapPartitions { iter =>
iter.map { case (lp, (pred, error)) =>
//計算本輪模型的預測值
val newPred = pred + tree.predict(lp.features) * treeWeight
//計算本輪誤差
val newError = loss.computeError(newPred, lp.label)
//newPred是累計,包含至本輪的模型輸出
(newPred, newError)
}
}
newPredError
}
程式碼中使用到的函式我們之前都有介紹。
2.3.3.3. validation(early stop)
類似計算錯誤率,只是樣本使用validationInput,看平均誤差是否減少,如果不能使誤差減小就結束訓練,相當於出現過擬合了;如果能,就繼續訓練,並且記錄最好的模型的index。這裡一次誤差變大就結束訓練比較武斷,最好應該有一定的閾值,避免單次訓練的波動。程式碼比較簡單,就不放了。
2.3.3.4. 訓練收尾
訓練完成後,根據記錄的最優模型的index,構造GradientBoostedTreesModel。
3.結語
從上面的分析可以看到,由於spark在Random Forest特徵方面的限制,以及GBDT實現中直接使用均值作為葉子節點的輸出值,early stop等,spark在樹模型上的精度可能會差一點,實際使用的話,最好與其他實現比較後決定是否使用。
相關推薦
spark mllib原始碼分析之DecisionTree與GBDT
我們在前面的文章講過,在spark的實現中,樹模型的依賴鏈是GBDT-> Decision Tree-> Random Forest,前面介紹了最基礎的Random Forest的實現,在此基礎上我們介紹Decision Tree和GBDT的實現
spark mllib原始碼分析之二分類邏輯迴歸evaluation
在邏輯迴歸分類中,我們評價分類器好壞的主要指標有精準率(precision),召回率(recall),F-measure,AUC等,其中最常用的是AUC,它可以綜合評價分類器效能,其他的指標主要偏重一些方面。我們介紹下spark中實現的這些評價指標,便於使用sp
spark mllib原始碼分析之隨機森林(Random Forest)(二)
4. 特徵處理 這部分主要在DecisionTree.scala的findSplitsBins函式,將所有特徵封裝成Split,然後裝箱Bin。首先對split和bin的結構進行說明 4.1. 資料結構 4.1.1. Split cl
spark mllib原始碼分析之L-BFGS(一)
1. 使用 spark給出的example中涉及到LBFGS有兩個,分別是LBFGSExample.scala和LogisticRegressionWithLBFGSExample.scala,第一個是直接使用LBFGS直接訓練,需要指定一系列優化引數,優
spark mllib原始碼分析之隨機森林(Random Forest)(三)
6. 隨機森林訓練 6.1. 資料結構 6.1.1. Node 樹中的每個節點是一個Node結構 class Node @Since("1.2.0") ( @Since("1.0.0") val id: Int, @S
spark mllib原始碼分析之邏輯迴歸彈性網路ElasticNet(一)
spark在ml包中將邏輯迴歸封裝了下,同時在演算法中引入了L1和L2正則化,通過elasticNetParam來調節兩種正則化的係數,同時根據選擇的正則化,決定使用L-BFGS還是OWLQN優化,是謂Elastic Net。 1. 輔助類 我們首先介紹
Spark core原始碼分析之spark叢集的啟動(二)
2.2 Worker的啟動 org.apache.spark.deploy.worker 1 從Worker的伴生物件的main方法進入 在main方法中首先是得到一個SparkConf例項conf,然後將conf和啟動Worker傳入的引數封裝得到Wor
netty原始碼分析之-SimpleChannelInboundHandler與ChannelInboundHandlerAdapter詳解(6)
每一個Handler都一定會處理出站或者入站(也可能兩者都處理)資料,例如對於入站的Handler可能會繼承SimpleChannelInboundHandler或者ChannelInboundHandlerAdapter,而SimpleChannelIn
Spark SQL 原始碼分析之Physical Plan 到 RDD的具體實現
我們都知道一段sql,真正的執行是當你呼叫它的collect()方法才會執行Spark Job,最後計算得到RDD。 lazy val toRdd: RDD[Row] = executedPlan.execute() Spark Plan基本包含4種操作型別,即Bas
Spark MLlib原始碼分析—Word2Vec原始碼詳解
以下程式碼是我依據SparkMLlib(版本1.6)中Word2Vec原始碼改寫而來,基本算是照搬。此版Word2Vec是基於Hierarchical Softmax的Skip-gram模型的實現。 在決定讀懂原始碼前,博主建議讀者先看一下《Word2Vec_
Prometheus 實戰於原始碼分析之API與聯邦
在進行原始碼講解關於prometheus還有一些配置和使用,需要解釋一下。首先是API的使用,prometheus提供了一套HTTP的介面 curl http://localhost:9090/api/v1/query?query=go_goroutine
Spark MLlib原始碼解讀之樸素貝葉斯分類器,NaiveBayes
Spark MLlib 樸素貝葉斯NaiveBayes 原始碼分析 基本原理介紹 首先是基本的條件概率求解的公式。 P(A|B)=P(AB)P(B) 在現實生活中,我們經常會碰到已知一個條件概率,求得兩個時間交換後的概率的問題。也就是在已知P(A
Mybatis原始碼分析之Spring與Mybatis整合MapperScannerConfigurer處理過程原始碼分析
前面文章分析了這麼多關於Mybatis原始碼解析,但是我們最終使用的卻不是以前面文章的方式,編寫自己mybatis_config.xml,而是最終將配置融合在spring的配置檔案中。有了前面幾篇部落格的分析,相信這裡會容易理解些關於Mybatis的初始化及
Realm原始碼分析之copyToRealm與copyToRealmOrUpdate
createObject 在Realm原始碼分析之Writes中已經詳細追蹤過createObject的執行流程,此處不再贅述。 createObject有如下的兩個過載方法,區別是如果Model沒有指明主鍵使用前者,否則使用後者: createObjec
Spark MLlib原始碼分析—TFIDF原始碼詳解
以下程式碼是我依據SparkMLlib(版本1.6) 1、HashingTF 是使用雜湊表來儲存分詞,並計算分詞頻數(TF),生成HashMap表。在Map中,K為分詞對應索引號,V為分詞的頻數。在宣告HashingTF 時,需要設定numFeatures,該
netty原始碼分析之-EventLoop與執行緒模型(1)
執行緒模型確定來程式碼的執行方式,我們總是必須規避併發執行可能會帶來的副作用,所以理解netty所採用的併發模型的影響很重要。netty使用了被稱為事件迴圈的EventLoop來執行任務來處理在連線的生命週期內發生的事件 執行緒模型 對於Even
Mybatis深入原始碼分析之Mapper與介面繫結原理原始碼分析
緊接上篇文章:Mybatis深入原始碼分析之SqlSessionFactoryBuilder原始碼分析,這裡再來分析下,Mappe
Spark原始碼分析之Spark Shell(上)
https://www.cnblogs.com/xing901022/p/6412619.html 文中分析的spark版本為apache的spark-2.1.0-bin-hadoop2.7。 bin目錄結構: -rwxr-xr-x. 1 bigdata bigdata 1089 Dec
symfony原始碼分析之容器的生成與使用
symfony 的容器是有一個編譯過程的,框架初始化的時候會執行Symfony\Component\HttpKernel\Kernel::initializationContainer ,這個方法會對程式碼進行檢查,看是否需要生成新的容器程式碼。如果需要 Symfony 會將各個類的依賴關係通過
Spark——Streaming原始碼解析之資料的產生與匯入
此文是從思維導圖中匯出稍作調整後生成的,思維腦圖對程式碼瀏覽支援不是很好,為了更好閱讀體驗,文中涉及到的原始碼都是刪除掉不必要的程式碼後的虛擬碼,如需獲取更好閱讀體驗可下載腦圖配合閱讀: 此博文共分為四個部分: DAG定義 Job動態生成 資料的產生與匯入 容錯 資料的產生與匯入主要分為以下五個部分