1. 程式人生 > >Spark MLlib之決策樹(DecisionTree)

Spark MLlib之決策樹(DecisionTree)

程式碼:

/**
 * Created by hadoop on 16-7-3.
 */

import org.apache.spark.{SparkContext, SparkConf}
import org.apache.spark.mllib.tree.DecisionTree
//import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.util.MLUtils

/**
 * Trains a decision-tree classifier with Spark MLlib on a LibSVM-formatted
 * data set, evaluates it on a held-out split, and prints the test error and
 * the learned tree.
 */
object DT {

  /**
   * Program entry point.
   *
   * Loads the data set, splits it 70/30 into training and test sets, trains
   * the classifier, and prints the misclassification rate on the test set
   * followed by a textual dump of the tree. The SparkContext is always
   * stopped before returning, even when a step fails.
   *
   * @param args command-line arguments (not used)
   */
  def main(args: Array[String]): Unit = {

    val conf = new SparkConf().setMaster("local").setAppName("DecisonTree")
    // Cluster alternative:
    //    val conf =  new SparkConf().setMaster("spark://192.168.0.100:7077").setAppName("DecisonTree")
    val sc = new SparkContext(conf)

    // try/finally guarantees the context is released even if loading,
    // training or evaluation throws.
    try {
      // Load and parse the data file.
      // Local-file alternative:
      //    val data = MLUtils.loadLibSVMFile(sc, "/home/hadoop/桌面/kdd_split2.txt")
      val data = MLUtils.loadLibSVMFile(sc, "hdfs://192.168.0.100:9000/spark/dt/kdd_split2.txt")

      // Split the data into training and test sets (30% held out for testing).
      val splits = data.randomSplit(Array(0.7, 0.3))
      val (trainingData, testData) = (splits(0), splits(1))

      // Training parameters.
      val numClasses = 5                            // number of target classes
      val categoricalFeaturesInfo = Map[Int, Int]() // empty: all features are treated as continuous
      val impurity = "gini"                         // impurity measure used to score candidate splits
      val maxDepth = 5                              // maximum depth of the tree
      val maxBins = 32                              // maximum number of bins used when splitting features

      // Train a DecisionTree model.
      val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
        impurity, maxDepth, maxBins)

      // Evaluate the model on the test instances: pair each true label with its prediction.
      val labelAndPreds = testData.map { point =>
        val prediction = model.predict(point.features)
        (point.label, prediction)
      }

      // Test error = misclassified instances / all test instances.
      val testErr = labelAndPreds.filter(r => r._1 != r._2).count().toDouble / testData.count()
      println(s"Test Error = $testErr")
      println(s"Learned classification tree model:\n${model.toDebugString}")

      // Save and load model
      //    model.save(sc, "target/tmp/myDecisionTreeClassificationModel")
      //    val sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeClassificationModel")
    } finally {
      sc.stop()
    }
  }
}

輸入資料格式:

label index1:value1 index2:value2 ...

執行結果:

Test Error = 1.04026647139573E-4

Learned classification tree model:
DecisionTreeModel classifier of depth 3 with 7 nodes
  If (feature 0 <= 2.0)
   Predict: 1.0
  Else (feature 0 > 2.0)
   If (feature 0 <= 3.0)
    If (feature 1 <= 2.0)
     Predict: 4.0
    Else (feature 1 > 2.0)
     Predict: 3.0
   Else (feature 0 > 3.0)
    Predict: 2.0

..............