1. 程式人生 > >基於spark mllib的gbt演算法例項

基於spark mllib的gbt演算法例項

背景:公司需要使用spark mllib進行預測,基於這個需求,使用spark mllib自帶的gbm進行預測。

程式碼1:

部落格
學院
下載
圖文課
論壇
APP
問答
商城
VIP會員
活動
招聘
ITeye
GitChat

搜CSDN
寫部落格賺零錢傳資源

關注和收藏在這裡
Markdown編輯器
富文字編輯器
檢視主頁
內容
文章管理
專欄管理
評論管理
個人分類管理
Chat快問 new
部落格搬家
設定
部落格設定
欄目管理
CSDN部落格QQ交流群


掃一掃二維碼
或點選這裡加入群聊


輸入文章標題

文章標籤:
新增標籤
最多新增5個標籤

個人分類:
新增新分類
文章型別:
 *
部落格分類:
 *
私密文章:
  


import ml.dmlc.xgboost4j.scala.spark.XGBoost
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql._
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.{Vector, Vectors}

/**
  *  使用spark自帶的演算法
  *  資料集為my_train.csv和my_test.csv
  *  這是個穩定的版本  只有測試成功再往裡面加東西
  *
  */
object myCallXGBoost {
  Logger.getLogger("org").setLevel(Level.WARN)

  def main(args: Array[String]): Unit = {

    //val inputPath = args(0)
    val inputPath = "data"
    print("*******************"+inputPath)
    // create SparkSession
    val spark = SparkSession
      .builder()
      .appName("myCallXGBoost")
      .config("spark.executor.memory", "2G")
      .config("spark.executor.cores", "4")
      .config("hive.metastore.uris","thrift://xxxxxxxxxxxx")
      .config("spark.driver.memory", "1G")
      .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .config("spark.default.parallelism", "4")
      .enableHiveSupport()
      //.master("local[*]")
      .getOrCreate()


    //step1 資料準備工作
    //從csv中讀取資料
    //val myTrainCsv = spark.read.option("header", "true").option("inferSchema", true).csv(  inputPath+"/my_train.csv")
    //val myTestCsv = spark.read.option("header", "true").option("inferSchema", true).csv(  inputPath+"/my_test.csv")

    val myTrainCsv = spark.sql("select * from dm_analysis.lsm_xgboost_train")
    val myTestCsv = spark.sql("select * from dm_analysis.lsm_xgboost_test")

    myTrainCsv.show(10)

    // 動態資料型別轉化 將any型別轉化為double
    def toDoubleDynamic(x: Any) = x match {
      case s: String => s.toDouble
      case jn: java.lang.Number => jn.doubleValue()
      case _ => throw new ClassCastException("cannot cast to double")
    }

    import spark.implicits._
    
    //這裡檢視到資料應該已經全部轉換成double型別了
    myTrainCsv.printSchema()

    //這是一個比較完整的版本  需要將features的所有行新增進來
    //直接使用row  每行來給程式賦值
    //需要提前將資料處理成 (label,features...)的格式
    val mydata = myTrainCsv.drop("_c0").map{row =>
      val row_len = row.length
      var myList = new Array[Double](row_len-1)
      for(i<- 1 to (row_len-1)){
        myList(i-1) = toDoubleDynamic(row(i))
      }
      val features = Vectors.dense(myList)
      LabeledPoint(toDoubleDynamic(row(0)), features)
    }

    mydata.show(10)
    val splits = mydata.randomSplit(Array(0.7, 0.3))
    val (trainingData, testData) = (splits(0), splits(1))


    //step2 準備訓練模型
    val boostingStrategy = BoostingStrategy.defaultParams("Regression")
    boostingStrategy.numIterations = 3 // Note: Use more iterations in practice.
    boostingStrategy.treeStrategy.maxDepth = 5
    // Empty categoricalFeaturesInfo indicates all features are continuous.
    boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]()

    val model = GradientBoostedTrees.train(trainingData.rdd, boostingStrategy)

    // Evaluate model on test instances and compute test error
    val labelsAndPredictions = testData.map { point =>
      val prediction = model.predict(point.features)
      (point.label, prediction)
    }
    val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}
    println(s"Test Mean Squared Error = $testMSE")
    println(s"Learned regression GBT model:\n ${model.toDebugString}")

    spark.stop()
  }
}

 

程式碼2: