Apache Spark MLlib學習筆記(六)MLlib決策樹類演算法原始碼解析 2
阿新 • • 發佈:2019-02-07
/**
 * Convenience entry point for training a classification decision tree.
 *
 * Resolves the impurity name to an impurity object with `Impurities.fromString` and then
 * delegates to `train`, fixing the algorithm to `Classification` and the quantile
 * calculation strategy to `Sort`.
 *
 * @param input training data as an RDD of LabeledPoint
 * @param numClasses number of classes, forwarded unchanged to `train`
 * @param categoricalFeaturesInfo map of (feature index -> number of categories), forwarded
 *                                unchanged to `train`
 * @param impurity name of the impurity measure, parsed by `Impurities.fromString`
 * @param maxDepth maximum tree depth, forwarded unchanged to `train`
 * @param maxBins forwarded unchanged to `train` (presumably the maximum number of bins used
 *                when discretizing features -- TODO confirm against Strategy)
 * @return the DecisionTreeModel produced by `train`
 */
// NOTE(review): the closing brace of this method is not present in this excerpt.
def trainClassifier(
input: RDD[LabeledPoint],
numClasses: Int,
categoricalFeaturesInfo: Map[Int, Int],
impurity: String,
maxDepth: Int,
maxBins: Int) : DecisionTreeModel = {
// Convert the impurity name into the impurity object expected by `train`.
val impurityType = Impurities.fromString(impurity)
// Note the argument order: `train` takes maxDepth before numClasses.
train(input, Classification, impurityType, maxDepth, numClasses, maxBins, Sort,categoricalFeaturesInfo)
/**
 * Trains a decision tree from fully specified settings.
 *
 * Packs all settings into a `Strategy`, wraps it in a `DecisionTree`, and runs that tree
 * on `input`.
 *
 * @param input training data as an RDD of LabeledPoint
 * @param algo the algorithm kind passed to `Strategy`
 * @param impurity the impurity measure passed to `Strategy`
 * @param maxDepth maximum tree depth passed to `Strategy`
 * @param numClasses number of classes passed to `Strategy`
 * @param maxBins bin limit passed to `Strategy`
 * @param quantileCalculationStrategy quantile strategy passed to `Strategy`
 * @param categoricalFeaturesInfo map of (feature index -> number of categories) passed to
 *                                `Strategy`
 * @return the DecisionTreeModel returned by `DecisionTree.run`
 */
// NOTE(review): the closing brace of this method is not present in this excerpt.
def train(
input: RDD[LabeledPoint],
algo: Algo,
impurity: Impurity,
maxDepth: Int,
numClasses: Int,
maxBins: Int,
quantileCalculationStrategy: QuantileStrategy,
categoricalFeaturesInfo: Map[Int,Int]) : DecisionTreeModel = {
// Bundle every training setting into one Strategy object.
val strategy = new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
quantileCalculationStrategy, categoricalFeaturesInfo)
// Training itself is done by DecisionTree.run.
new DecisionTree(strategy).run(input)
@param input Training dataset: RDD,標籤是{0, 1, ..., numClasses-1}.
@param algo :classification(分類) 或者 regression(迴歸)
@param impurity:不純度(impurity)的計算方法,用於計算資訊增益,包括gini,entropy,variance。
@param maxDepth:樹的最大深度,0 代表只有根節點,1代表1個根節點,兩個葉子節點。
@param numClasses:分類的數量,預設值是2。
@param maxBins:劃分特徵時使用的最大分箱(bin)數量;對於分類屬性,該值不應小於其類別數。
@param quantileCalculationStrategy:計算分位數演算法
@param categoricalFeaturesInfo:儲存分類特徵的鍵值對 (n -> k),表示特徵 n 有 k 個類別,
類別下標從 0 開始:{0, 1, ..., k-1}。
// Decision tree trainer configured by a single Strategy. As the `run` method shows, training
// is delegated to a RandomForest that contains exactly one tree.
// NOTE(review): this excerpt has lost its Scaladoc delimiters (the `*` lines below were inside
// a `/** ... */` block), the tail of `run`, and the closing braces.
class DecisionTree (private val strategy: Strategy) extends Serializable with Logging {
* Method to train a decision tree model over an RDD
* @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
* @return DecisionTreeModel that can be used for prediction
def run(input: RDD[LabeledPoint]): DecisionTreeModel = {
// Note: random seed will not be used since numTrees = 1.
// Build a forest of one tree that considers all features ("all") with a fixed seed of 0.
val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0)
// Train the single-tree forest on the input.
// NOTE(review): the excerpt stops here; presumably the method then returns the forest's only
// tree (the accompanying text mentions `rfModel.trees(0)`) -- confirm against the full source.
val rfModel = rf.run(input)
可以看出程式最終呼叫了RandomForest的方法,即對於spark MLlib,決策樹是隨機森林的一個特例(只有一棵樹,numTrees = 1),因此用 rfModel.trees(0) 取出下標為0的那棵唯一的樹作為結果。下面進入RandomForest原始檔,路徑是 mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
/**
 * Trains the forest on `input` and returns a RandomForestModel with one DecisionTreeModel
 * per tree.
 *
 * Steps visible in this excerpt: build metadata, find splits and bins, convert the input to
 * binned and bagged points, validate the depth and memory limits, then repeatedly select a
 * group of queued nodes and split them until the queue is empty.
 *
 * NOTE(review): this excerpt is missing several statements and closing braces (flagged
 * inline below), so it does not compile as shown.
 *
 * @param input training data as an RDD of LabeledPoint
 * @return a RandomForestModel built from the trained root node of every tree
 */
def run(input: RDD[LabeledPoint]): RandomForestModel = {
// Timer handed to DecisionTree.findBestSplits further down.
val timer = new TimeTracker()
// Re-tag the RDD with the LabeledPoint class before using it.
val retaggedInput = input.retag(classOf[LabeledPoint])
// Metadata derived from the input and the training settings (exposes maxBins, numFeatures,
// numBins, numFeaturesPerNode and subsamplingFeatures, all read below).
val metadata =
DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)
logDebug("algo = " + strategy.algo)
logDebug("numTrees = " + numTrees)
logDebug("seed = " + seed)
logDebug("maxBins = " + metadata.maxBins)
logDebug("featureSubsetStrategy = " + featureSubsetStrategy)
logDebug("numFeaturesPerNode = " + metadata.numFeaturesPerNode)
logDebug("subsamplingRate = " + strategy.subsamplingRate)
// Find the splits and the corresponding bins (interval between the splits) using a sample
// of the input data.
val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata)
logDebug("numBins: feature: number of bins")
// NOTE(review): the body of this lambda and its closing `})` are not present in the excerpt.
logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
// Bin feature values (TreePoint representation).
// Cache input RDD for speedup during multiple passes.
val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata)
// Sample with replacement only when more than one tree is trained
// (the expression is equivalent to `numTrees > 1`).
val withReplacement = if (numTrees > 1) true else false
// Wrap the binned points as bagged points and persist them to memory, spilling to disk.
val baggedInput
= BaggedPoint.convertToBaggedRDD(treeInput,
strategy.subsamplingRate, numTrees,
withReplacement, seed).persist(StorageLevel.MEMORY_AND_DISK)
// depth of the decision tree
val maxDepth = strategy.maxDepth
// Reject depths above 30 up front.
require(maxDepth <= 30,
s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.")
// Max memory usage for aggregates
// TODO: Calculate memory usage more precisely.
// Convert the configured limit from megabytes to bytes.
val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L
logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
// Upper bound on the aggregate memory needed for a single node.
// NOTE(review): this block is truncated in the excerpt (unbalanced parentheses on the `Some(`
// line, empty `else` branch, no closing braces).
val maxMemoryPerNode = {
val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
// Find numFeaturesPerNode largest bins to get an upper bound on memory usage.
Some(metadata.numBins.zipWithIndex.sortBy(- _._1)
} else {
// Presumably 8 bytes per aggregate entry -- TODO confirm.
RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
// Fail fast when even one node's aggregates would not fit within the configured limit.
require(maxMemoryPerNode <= maxMemoryUsage,
s"RandomForest/DecisionTree given maxMemoryInMB = ${strategy.maxMemoryInMB}," +
" which is too small for the given features." +
s" Minimum value = ${maxMemoryPerNode / (1024L * 1024L)}")
* The main idea here is to perform group-wise training of the decision tree nodes thus
* reducing the passes over the data from (# nodes) to (# nodes / maxNumberOfNodesPerGroup).
* Each data sample is handled by a particular node (or it reaches a leaf and is not used
* in lower levels).
// Create an RDD of node Id cache.
// At first, all the rows belong to the root nodes (node Id == 1).
// NOTE(review): the constructor call that receives these named arguments, and the `else`
// branch, are not present in the excerpt.
val nodeIdCache = if (strategy.useNodeIdCache) {
data = baggedInput,
numTrees = numTrees,
checkpointInterval = strategy.checkpointInterval,
initVal = 1))
} else {
// FIFO queue of nodes to train: (treeIndex, node)
val nodeQueue = new mutable.Queue[(Int, Node)]()
// NOTE(review): no seed is applied to this generator in the visible lines; confirm whether
// the full source seeds it with `seed`.
val rng = new scala.util.Random()
// Allocate and queue root nodes.
// One empty root per tree, each with node index 1.
val topNodes: Array[Node] = Array.fill[Node](numTrees)(Node.emptyNode(nodeIndex = 1))
Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex))))
// Keep training until no queued nodes remain.
while (nodeQueue.nonEmpty) {
// Collect some nodes to split, and choose features for each node (if subsampling).
// Each group of nodes may come from one or multiple trees, and at multiple levels.
val (nodesForGroup, treeToNodeToIndexInfo) =
RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
// Sanity check (should never occur):
assert(nodesForGroup.size > 0,
s"RandomForest selected empty nodesForGroup. Error for unknown reason.")
// Choose node splits, and enqueue new nodes as needed.
DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
treeToNodeToIndexInfo, splits, bins, nodeQueue, timer, nodeIdCache = nodeIdCache)
// NOTE(review): the closing brace of the while loop is not present in the excerpt.
logInfo("Internal timing for DecisionTree:")
// Delete any remaining checkpoints used for node Id cache.
// NOTE(review): the body of the `try` is empty in this excerpt; the delete call appears to
// have been dropped.
if (nodeIdCache.nonEmpty) {
try {
} catch {
// Cleanup is best-effort: an IOException is logged as a warning and not rethrown.
case e:IOException =>
logWarning(s"delete all chackpoints failed. Error reason: ${e.getMessage}")
// Wrap each trained root node in a DecisionTreeModel and assemble the forest model.
val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo))
new RandomForestModel(strategy.algo, trees)