Spark-MLlib例項——決策樹
阿新 • • 發佈:2019-01-25
Spark-MLlib例項——決策樹
通俗來說,決策樹分類的思想類似於找物件。現想象一個女孩的母親要給這個女孩介紹男朋友,於是有了下面的對話:
- 女兒:多大年紀了?
- 母親:26。
- 女兒:長的帥不帥?
- 母親:挺帥的。
- 女兒:收入高不?
- 母親:不算很高,中等情況。
- 女兒:是公務員不?
- 母親:是,在稅務局上班呢。
- 女兒:那好,我去見見。
以上是決策的經典例子,用spark-mllib怎麼實現訓練與預測呢
1、首先準備測試資料集
訓練資料集 Tree1
欄位說明:
是否見面, 年齡 是否帥 收入(1 高 2 中等 0 少) 是否公務員
- 0,32 1 1 0
- 0,25 1 2 0
- 1,29 1 2 1
- 1,24 1 1 0
- 0,31 1 1 0
- 1,35 1 2 1
- 0,30 0 1 0
- 0,31 1 1 0
- 1,30 1 2 1
- 1,21 1 1 0
- 0,21 1 2 0
- 1,21 1 2 1
- 0,29 0 2 1
- 0,29 1 0 1
- 0,29 0 2 1
- 1,30 1 1 0
- 0,32 1 2 0
- 1,27 1 1 1
- 1,29 1 1 0
- 1,25 1 2 1
- 0,23 0 2 1
2、Spark-MLlib決策樹應用程式碼
-
import
- import org.apache.spark.mllib.feature.HashingTF
- import org.apache.spark.mllib.linalg.Vectors
- import org.apache.spark.mllib.regression.LabeledPoint
- import org.apache.spark.mllib.tree.DecisionTree
- import org.apache.spark.mllib.util.MLUtils
-
import org.apache.spark.{SparkConf, SparkContext}
- /**
- * 決策樹分類
- */
- object TreeDemo {
- def main(args: Array[String]) {
- val conf = new SparkConf().setAppName("DecisionTree").setMaster("local")
- val sc = new SparkContext(conf)
- Logger.getRootLogger.setLevel(Level.WARN)
- //訓練資料
- val data1 = sc.textFile("data/Tree1.txt")
- //測試資料
- val data2 = sc.textFile("data/Tree2.txt")
- //轉換成向量
- val tree1 = data1.map { line =>
- val parts = line.split(',')
- LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))
- }
- val tree2 = data2.map { line =>
- val parts = line.split(',')
- LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))
- }
- //賦值
- val (trainingData, testData) = (tree1, tree2)
- //分類
- val numClasses = 2
- val categoricalFeaturesInfo = Map[Int, Int]()
- val impurity = "gini"
- //最大深度
- val maxDepth = 5
- //最大分支
- val maxBins = 32
- //模型訓練
- val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
- impurity, maxDepth, maxBins)
- //模型預測
- val labelAndPreds = testData.map { point =>
- val prediction = model.predict(point.features)
- (point.label, prediction)
- }
- //測試值與真實值對比
- val print_predict = labelAndPreds.take(15)
- println("label" + "\t" + "prediction")
- for (i <- 0 to print_predict.length - 1) {
- println(print_predict(i)._1 + "\t" + print_predict(i)._2)
- }
- //樹的錯誤率
- val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count()
- println("Test Error = " + testErr)
- //列印樹的判斷值
- println("Learned classification tree model:\n" + model.toDebugString)
- }
- }
3、測試結果:
- label prediction
- 0.0 0.0
- 1.0 1.0
- 1.0 1.0
- 1.0 1.0
- 0.0 0.0
- Test Error = 0.0
- Learned classification tree model:
列印決策樹的分支值,這裡最大深度為 5 ,對應的樹結構:
- Learned classification tree model:
- DecisionTreeModel classifier of depth 4 with 11 nodes
- If (feature 1 <= 0.0)
- Predict: 0.0
- Else (feature 1 > 0.0)
- If (feature 3 <= 0.0)
- If (feature 0 <= 30.0)
- If (feature 2 <= 1.0)
- Predict: 1.0
- Else (feature 2 > 1.0)
- Predict: 0.0
- Else (feature 0 > 30.0)
- Predict: 0.0
- Else (feature 3 > 0.0)
- If (feature 2 <= 0.0)
- Predict: 0.0
- Else (feature 2 > 0.0)
- Predict: 1.0