Spark隨機森林RandomForest
位於ml/tree/impl/目錄下。mllib目錄下的隨機森林演算法也是呼叫的ml下的RandomForest。ml是mllib的最新實現,將來是要替換掉mllib庫的。
-
- RandomForest核心程式碼
- train方法
- RandomForest核心程式碼
每次迭代將要計算的node推入堆疊,選擇參與計算的抽樣資料,計算該節點,迴圈該過程。
while (nodeStack.nonEmpty) {
// Collect some nodes to split, and choose features for each node (if subsampling
// Each group of nodes may come from one or multiple trees, and at multiple levels.
val (nodesForGroup, treeToNodeToIndexInfo) =
RandomForest.selectNodesToSplit(nodeStack, maxMemoryUsage,
// Sanity check (should never occur):
assert(nodesForGroup.nonEmpty,
s"RandomForest selected empty nodesForGroup. Error for unknown reason.")
// Only send trees to worker if they contain nodes being split this iteration.
val topNodesForGroup: Map[Int, LearningNode] =
nodesForGroup.keys.map(treeIdx => treeIdx -> topNodes(treeIdx)).toMap
// Choose node splits, and enqueue new nodes as needed.
timer.start("findBestSplits")
RandomForest.findBestSplits(baggedInput, metadata, topNodesForGroup, nodesForGroup,
treeToNodeToIndexInfo, splits, nodeStack, timer, nodeIdCache)
timer.stop("findBestSplits")
}
-
- RandomForest演算法
- training
- RandomForest演算法
nodesForGroup:本次等待處理的節點集合。
topNodesForGroup:nodesForGroup所對應的每顆樹的根節點。
def run(
input: RDD[LabeledPoint],
strategy: OldStrategy,
numTrees: Int,
featureSubsetStrategy: String,
seed: Long,
instr: Option[Instrumentation[_]],
parentUID: Option[String] = None): Array[DecisionTreeModel]
run方法返回DecisionTreemodel陣列,每個成員是一個決策樹,森林對每個決策樹預測值加權得到最終預測結果。
迴圈處理節點:
(1)RandomForest.selectNodesToSplit
(2)RandomForest.findBestSplits
直到所有nodes都處理完畢,則迴圈結束,開始構造決策樹模型,建立DecisionTreeClassificationModel。
所以這裡最關鍵的是下面兩個方法:
(1)RandomForest.selectNodesToSplit
(2)RandomForest.findBestSplits
-
-
- selectNodesToSplit
-
選擇進行切分的節點。根據記憶體等狀態選擇本次切分的節點集合。返回(NodesForGroup,TreeToNodeToIndexInfo)。該方法的作用就是檢查記憶體是否夠用,在記憶體足夠的情況下其實可以忽略該函式。
森林的每個樹頂點儲存在stack中,該方法從此stack中找出可以進行切分的節點,然後呼叫findBestSplits方法構造決策樹。stack中的元素是動態變化的。
資料結構:
NodesForGroup:HashMap[Int, mutable.ArrayBuffer[LearningNode]]
key是treeIndex,value是node列表,表示屬於該tree的node列表。
TreeToNodeToIndexInfo:HashMap[Int, mutable.HashMap[Int, NodeIndexInfo]]
key是treeIndex。
value是HashMap,其中key是nodeId,value是nodeIndexInfo(有featureSubset屬性和本次group內的node數目)。由selectNodesToSplit方法建立該物件。featureSubset就是本節點需要處理的特徵集合(是所有特徵的子集)。
-
-
- findBestSplits
-
隨機森林的【主函式】,找到最好切分。
重點分析:
/**
* Given a group of nodes, this finds the best split for each node.
*
* @param input Training data: RDD of [[TreePoint]]
* @param metadata Learning and dataset metadata
* @param topNodesForGroup For each tree in group, tree index -> root node.
* Used for matching instances with nodes.
* @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree
* @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo,
* where nodeIndexInfo stores the index in the group and the
* feature subsets (if using feature subsets).
* @param splits possible splits for all features, indexed (numFeatures)(numSplits)
* @param nodeStack Queue of nodes to split, with values (treeIndex, node).
* Updated with new non-leaf nodes which are created.
* @param nodeIdCache Node Id cache containing an RDD of Array[Int] where
* each value in the array is the data point's node Id
* for a corresponding tree. This is used to prevent the need
* to pass the entire tree to the executors during
* the node stat aggregation phase.
*/
private[tree] def findBestSplits(
input: RDD[BaggedPoint[TreePoint]],
metadata: DecisionTreeMetadata,
topNodesForGroup: Map[Int, LearningNode],
nodesForGroup: Map[Int, Array[LearningNode]],
treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]],
splits: Array[Array[Split]],
nodeStack: mutable.Stack[(Int, LearningNode)],
timer: TimeTracker = new TimeTracker,
nodeIdCache: Option[NodeIdCache] = None): Unit = {
。。。
}
尋找最優切分的函式。
為簡化程式碼分析,忽略程式碼中優化部分(入cache機制等)。
-
-
- findSplits
-
找出splits,供選擇最優分解特徵值演算法使用。
findSplitsBySorting:實際完成findSplits功能。
-
-
- binsToBestSplit
-
也是重點方法。
尋找當前node的最優特徵和特徵值,findBestSplits會呼叫到。
包含兩層迴圈,一是特徵迴圈,內部再巢狀該特徵的特徵值增益迴圈計算。最後找出最優解。
步驟:
首先獲取要spit的節點的level。獲取node增益狀態。
過濾合法的split,如果某特feature的split為空,則忽略。
/**
* Find the best split for a node.
*
* @param binAggregates Bin statistics.
* @return tuple for best split: (Split, information gain, prediction at node)
*/
private[tree] def binsToBestSplit(
binAggregates: DTStatsAggregator,
splits: Array[Array[Split]],
featuresForNode: Option[Array[Int]],
node: LearningNode): (Split, ImpurityStats) = {
。。。
}
-
-
- calculateImpurityStats
-
計算節點左右子數的增益或者熵。
calculateImpurityStats
gain(增益)= 父node的impurity-左子數的impurity*權重-右子數的impurity*權重。
-
-
- extractMultiClassCategories
-
從離散型數值抽取出多個classLabel,和findSplitsForContinuousFeature對應。
返回離散的分割類別。
-
-
- findSplitsForContinuousFeature
-
對連續特徵抽取分割線,比如等分劃分特徵最小值和最大值之間的距離,劃分成N個split,每個split包含一個合理劃分連續數值的分割點,分割點是一個double數值。
主要輸入引數:每條記錄的對應feature值的陣列。
返回各個分割的閾值。
-
-
- aggregateSizeForNode
-
計算每個node的統計彙總維度,對於分類模型,總的統計維度=分類類別數*總的bin數(也就是每個特徵的可列舉數目)。
-
- 決策樹:DecisionTreeClassifier
單個決策樹,構造隨機森林的引數,設定子樹的數目為1,然後呼叫隨機森林演算法RandomForest生成決策森林,返回第一個節點。
-
- GBT分類
梯度提升決策樹演算法。