
Running XGBoost on Spark: the Scala interface
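The listing below trains a binary classifier on the agaricus (mushroom) LibSVM demo data that ships under demo/data in the XGBoost source tree, using the RDD-based trainWithRDD entry point of xgboost4j-spark, then scores the test split. The call signatures match the 0.7x line of xgboost4j-spark; later releases dropped the RDD API in favour of DataFrame-based estimators.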

package com.meituan.spark_xgboost
import org.apache.log4j.{Level, Logger}
import ml.dmlc.xgboost4j.scala.spark.XGBoost
import org.apache.spark.sql.SparkSession
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vectors
object XgboostR {

  def main(args: Array[String]): Unit = {
    Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)
    Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)
    val spark = SparkSession.builder.master("local").appName("example")
      .config("spark.sql.warehouse.dir", s"file:///Users/shuubiasahi/Documents/spark-warehouse")
      .config("spark.sql.shuffle.partitions", "20")
      // spark.serializer must be set before the SparkContext exists;
      // calling spark.conf.set(...) after getOrCreate() has no effect on it.
      .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .getOrCreate()
    val path = "/Users/shuubiasahi/Documents/workspace/xgboost/demo/data/"
    val trainString = "agaricus.txt.train"
    val testString = "agaricus.txt.test"
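    // agaricus.txt.{train,test} are the LibSVM-format mushroom demo files
    // that ship under demo/data in the XGBoost source tree.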
    val train = MLUtils.loadLibSVMFile(spark.sparkContext, path + trainString)
    val test = MLUtils.loadLibSVMFile(spark.sparkContext, path + testString)
    // MLUtils.loadLibSVMFile returns the old mllib LabeledPoint, while the
    // xgboost4j-spark RDD API expects org.apache.spark.ml.feature.LabeledPoint,
    // so the labels and features are repacked here.
    val traindata = train.map { x =>
      val f = x.features.toArray
      val v = x.label
      LabeledPoint(v, Vectors.dense(f))
    }
    // For prediction only the feature vectors are needed, not the labels.
    val testdata = test.map(x => Vectors.dense(x.features.toArray))
    

    val numRound = 15

    // For a regression task one would instead set:
    //   "objective" -> "reg:linear", // learning task and objective
    //   "eval_metric" -> "rmse",     // evaluation metric for the validation data
    
    val paramMap = List(
      "eta" -> 1f,
      "max_depth" -> 5, // maximum tree depth; default 6, valid range [1, ∞]
      "silent" -> 1, // 0 prints runtime messages, 1 runs silently; default 0
      "objective" -> "binary:logistic", // learning task and objective
      "lambda" -> 2.5,
      "nthread" -> 1 // number of threads for XGBoost; defaults to the maximum available on the system
    ).toMap
    println(paramMap)
    

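    // trainWithRDD (xgboost4j-spark 0.7x) takes: the training RDD, the
    // parameter map, the number of boosting rounds, the number of workers,
    // a custom objective and a custom eval (both null here), the
    // external-memory flag, and the value treated as missing. Note that
    // 55 workers may be more than a "local" master can actually schedule.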
    val model = XGBoost.trainWithRDD(traindata, paramMap, numRound, 55, null, null, useExternalMemory = false, Float.NaN)
    println("success")
 
    val result = model.predict(testdata)
    result.take(10).foreach(println)
    spark.stop()
  }

}
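To reuse the trained booster elsewhere, the 0.7x releases of xgboost4j-spark expose Hadoop-file save/load helpers on the model and on the XGBoost object. The sketch below is a minimal illustration assuming that API and a hypothetical output path; it is not part of the original listing.

    // Persist the trained model, then reload it (xgboost4j-spark 0.7x helpers;
    // the path below is hypothetical).
    val modelPath = "/tmp/xgboost-agaricus-model"
    model.saveModelAsHadoopFile(modelPath)(spark.sparkContext)

    // Later, possibly in a different job:
    val restored = XGBoost.loadModelFromHadoopFile(modelPath)(spark.sparkContext)
    restored.predict(testdata).take(10).foreach(println)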