1. 程式人生 > >Spark-MLlib例項——決策樹

Spark-MLlib例項——決策樹

Spark-MLlib例項——決策樹

通俗來說,決策樹分類的思想類似於找物件。現想象一個女孩的母親要給這個女孩介紹男朋友,於是有了下面的對話:

  1. 女兒:多大年紀了?  
  2. 母親:26。  
  3. 女兒:長的帥不帥?  
  4. 母親:挺帥的。  
  5. 女兒:收入高不?  
  6. 母親:不算很高,中等情況。  
  7. 女兒:是公務員不?  
  8. 母親:是,在稅務局上班呢。  
  9. 女兒:那好,我去見見。  



以上是決策的經典例子,用spark-mllib怎麼實現訓練與預測呢

1、首先準備測試資料集

訓練資料集 Tree1

欄位說明:

是否見面, 年齡  是否帥  收入(1 高 2 中等 0 少)  是否公務員

  1. 0,32 1 1 0  
  2. 0,25 1 2 0  
  3. 1,29 1 2 1  
  4. 1,24 1 1 0  
  5. 0,31 1 1 0  
  6. 1,35 1 2 1  
  7. 0,30 0 1 0  
  8. 0,31 1 1 0  
  9. 1,30 1 2 1  
  10. 1,21 1 1 0  
  11. 0,21 1 2 0  
  12. 1,21 1 2 1  
  13. 0,29 0 2 1  
  14. 0,29 1 0 1  
  15. 0,29 0 2 1  
  16. 1,30 1 1 0  
測試資料集 Tree2
  1. 0,32 1 2 0  
  2. 1,27 1 1 1  
  3. 1,29 1 1 0  
  4. 1,25 1 2 1  
  5. 0,23 0 2 1  

2、Spark-MLlib決策樹應用程式碼
  1. import
     org.apache.log4j.{Level, Logger}  
  2. import org.apache.spark.mllib.feature.HashingTF  
  3. import org.apache.spark.mllib.linalg.Vectors  
  4. import org.apache.spark.mllib.regression.LabeledPoint  
  5. import org.apache.spark.mllib.tree.DecisionTree  
  6. import org.apache.spark.mllib.util.MLUtils  
  7. import org.apache.spark.{SparkConf, SparkContext}  
  8. /** 
  9.   * 決策樹分類 
  10.   */
  11. object TreeDemo {  
  12.   def main(args: Array[String]) {  
  13.     val conf = new SparkConf().setAppName("DecisionTree").setMaster("local")  
  14.     val sc = new SparkContext(conf)  
  15.     Logger.getRootLogger.setLevel(Level.WARN)  
  16.     //訓練資料
  17.     val data1 = sc.textFile("data/Tree1.txt")  
  18.     //測試資料
  19.     val data2 = sc.textFile("data/Tree2.txt")  
  20.     //轉換成向量
  21.     val tree1 = data1.map { line =>  
  22.       val parts = line.split(',')  
  23.       LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))  
  24.     }  
  25.     val tree2 = data2.map { line =>  
  26.       val parts = line.split(',')  
  27.       LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))  
  28.     }  
  29.     //賦值
  30.     val (trainingData, testData) = (tree1, tree2)  
  31.     //分類
  32.     val numClasses = 2
  33.     val categoricalFeaturesInfo = Map[Int, Int]()  
  34.     val impurity = "gini"
  35.     //最大深度
  36.     val maxDepth = 5
  37.     //最大分支
  38.     val maxBins = 32
  39.     //模型訓練
  40.     val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,  
  41.       impurity, maxDepth, maxBins)  
  42.     //模型預測
  43.     val labelAndPreds = testData.map { point =>  
  44.       val prediction = model.predict(point.features)  
  45.       (point.label, prediction)  
  46.     }  
  47.     //測試值與真實值對比
  48.     val print_predict = labelAndPreds.take(15)  
  49.     println("label" + "\t" + "prediction")  
  50.     for (i <- 0 to print_predict.length - 1) {  
  51.       println(print_predict(i)._1 + "\t" + print_predict(i)._2)  
  52.     }  
  53.     //樹的錯誤率
  54.     val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count()  
  55.     println("Test Error = " + testErr)  
  56.     //列印樹的判斷值
  57.     println("Learned classification tree model:\n" + model.toDebugString)  
  58.   }  
  59. }  


3、測試結果:

  1. label   prediction  
  2. 0.0 0.0  
  3. 1.0 1.0  
  4. 1.0 1.0  
  5. 1.0 1.0  
  6. 0.0 0.0  
  7. Test Error = 0.0  
  8. Learned classification tree model:  
可見真實值與預測值一致,Error為0

列印決策樹的分支值,這裡最大深度為 5 ,對應的樹結構:

  1. Learned classification tree model:  
  2. DecisionTreeModel classifier of depth 4 with 11 nodes  
  3.   If (feature 1 <= 0.0)  
  4.    Predict: 0.0  
  5.   Else (feature 1 > 0.0)  
  6.    If (feature 3 <= 0.0)  
  7.     If (feature 0 <= 30.0)  
  8.      If (feature 2 <= 1.0)  
  9.       Predict: 1.0  
  10.      Else (feature 2 > 1.0)  
  11.       Predict: 0.0  
  12.     Else (feature 0 > 30.0)  
  13.      Predict: 0.0  
  14.    Else (feature 3 > 0.0)  
  15.     If (feature 2 <= 0.0)  
  16.      Predict: 0.0  
  17.     Else (feature 2 > 0.0)  
  18.      Predict: 1.0  
可見預測出的分界值與真實一致,準確率與決策樹演算法,引數設定及訓練樣本的選擇覆蓋有關!