
Using Spark ML Offline-Trained Models for Online Prediction

Our company recently needed to take algorithm models trained offline and serve them online for real-time prediction. Leaving aside online feature engineering, our investigation found that the combination of jpmml-sparkml and jpmml-evaluator meets the requirement. Note, however, that the framework is released under the AGPL-3.0 license.

Solution: Spark ML + jpmml-sparkml + jpmml-evaluator
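Both JPMML libraries are published to Maven Central. A minimal sbt sketch of the dependencies (the version numbers here are assumptions; pick the jpmml-sparkml release that matches your Spark version):

libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-mllib"    % "2.1.0" % "provided", // assumed Spark version
  "org.jpmml"        %  "jpmml-sparkml"  % "1.2.7",              // PipelineModel -> PMML converter
  "org.jpmml"        %  "pmml-evaluator" % "1.3.10"              // PMML scoring at serving time
)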

Training a Random Forest model offline with Spark and saving it in PMML format:

import javax.xml.transform.stream.StreamResult
import com.jd.risk.utils.HadoopFileUtil
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.spark.examples.ml.DecisionTreeExample
import org.apache.spark.examples.mllib.AbstractParams
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler, VectorIndexer}
import org.apache.spark.ml.regression.{RandomForestRegressionModel, RandomForestRegressor}
import org.apache.spark.ml.{Pipeline, PipelineStage}
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.jpmml.model.{JAXBUtil, MetroJAXBUtil}
import org.jpmml.sparkml.ConverterUtil
import scopt.OptionParser
import scala.collection.mutable
import scala.language.reflectiveCalls
/**
  * Created by sjmei on 2017/01/19.
  */
object RandomForestPMMLTask {
  case class Params(
      input: String = null,
      modelDir: String = null,
      taskType:String = "train",
      testInput: String = "",
      dataFormat: String = "libsvm",
      algo: String = "classification",
      maxDepth: Int = 4,
      maxBins: Int = 32,
      minInstancesPerNode: Int = 1,
      minInfoGain: Double = 0.0,
      numTrees: Int = 5,
      featureSubsetStrategy: String = "auto",
      fracTest: Double = 0.2,
      cacheNodeIds: Boolean = false,
      checkpointDir: Option[String] = None,
      checkpointInterval: Int = 10) extends AbstractParams[Params]
  def main(args: Array[String]) {
    val defaultParams = Params()
    val parser = new OptionParser[Params]("RandomForestExample") {
      head("RandomForestExample: an example random forest app.")
      opt[String]("algo")
        .text(s"algorithm (classification, regression), default: ${defaultParams.algo}")
        .action((x, c) => c.copy(algo = x))
      opt[String]("taskType")
        .text(s"modelType, default: ${defaultParams.taskType}")
        .action((x, c) => c.copy(taskType = x))
      opt[Int]("maxDepth")
        .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
        .action((x, c) => c.copy(maxDepth = x))
      opt[Int]("maxBins")
        .text(s"max number of bins, default: ${defaultParams.maxBins}")
        .action((x, c) => c.copy(maxBins = x))
      opt[Int]("minInstancesPerNode")
        .text(s"min number of instances required at child nodes to create the parent split," +
        s" default: ${defaultParams.minInstancesPerNode}")
        .action((x, c) => c.copy(minInstancesPerNode = x))
      opt[Double]("minInfoGain")
        .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}")
        .action((x, c) => c.copy(minInfoGain = x))
      opt[Int]("numTrees")
        .text(s"number of trees in ensemble, default: ${defaultParams.numTrees}")
        .action((x, c) => c.copy(numTrees = x))
      opt[String]("featureSubsetStrategy")
        .text(s"number of features to use per node (supported:" +
        s" ${RandomForestClassifier.supportedFeatureSubsetStrategies.mkString(",")})," +
        s" default: ${defaultParams.numTrees}")
        .action((x, c) => c.copy(featureSubsetStrategy = x))
      opt[Double]("fracTest")
        .text(s"fraction of data to hold out for testing. If given option testInput, " +
        s"this option is ignored. default: ${defaultParams.fracTest}")
        .action((x, c) => c.copy(fracTest = x))
      opt[Boolean]("cacheNodeIds")
        .text(s"whether to use node Id cache during training, " +
        s"default: ${defaultParams.cacheNodeIds}")
        .action((x, c) => c.copy(cacheNodeIds = x))
      opt[String]("checkpointDir")
        .text(s"checkpoint directory where intermediate node Id caches will be stored, " +
        s"default: ${
          defaultParams.checkpointDir match {
            case Some(strVal) => strVal
            case None => "None"
          }
        }")
        .action((x, c) => c.copy(checkpointDir = Some(x)))
      opt[Int]("checkpointInterval")
        .text(s"how often to checkpoint the node Id cache, " +
        s"default: ${defaultParams.checkpointInterval}")
        .action((x, c) => c.copy(checkpointInterval = x))
      opt[String]("testInput")
        .text(s"input path to test dataset. If given, option fracTest is ignored." +
        s" default: ${defaultParams.testInput}")
        .action((x, c) => c.copy(testInput = x))
      opt[String]("dataFormat")
        .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
        .action((x, c) => c.copy(dataFormat = x))
      arg[String]("<input>")
        .text("input path to labeled examples")
        .required()
        .action((x, c) => c.copy(input = x))
      arg[String]("<modelDir>")
        .text("modelDir path to labeled examples")
        .required()
        .action((x, c) => c.copy(modelDir = x))
      checkConfig { params =>
        if (params.fracTest < 0 || params.fracTest >= 1) {
          failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).")
        } else {
          success
        }
      }
    }
    parser.parse(args, defaultParams) match {
      case Some(params) => {
        if(params.taskType.equalsIgnoreCase("train")){
          train(params)
        }
      }
      case _ => sys.exit(1)
    }
  }
  def train(params: Params): Unit = {
    val spark = SparkSession
      .builder
      .master("local")
      .appName(s"RandomForestExample with $params")
      .getOrCreate()
    params.checkpointDir.foreach(spark.sparkContext.setCheckpointDir)
    val algo = params.algo.toLowerCase
    println(s"RandomForestExample with parameters:\n$params")
    // Load training and test data and cache it (AlgoUtils.loadMaliceDataFrame is a
    // project-internal helper that loads the input and splits it into train/test).
    val (training: DataFrame, test: DataFrame) = AlgoUtils.loadMaliceDataFrame(spark.sparkContext, params.input, params.fracTest)
    // Set up Pipeline.
    val stages = new mutable.ArrayBuffer[PipelineStage]()
    // (1) For classification, re-index classes.
    val labelColName = if (algo == "classification") "indexedLabel" else "label"
    if (algo == "classification") {
      val labelIndexer = new StringIndexer()
        .setInputCol("label")
        .setOutputCol(labelColName)
      stages += labelIndexer
    }
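    // NOTE: the online scoring side must supply these eight features in exactly this order.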
    val vectorAssembler = new VectorAssembler()
    vectorAssembler.setInputCols(Array("degree","tcNum","pageRank","commVertexNum","normQ","gtRate","eqRate","ltRate"))
    vectorAssembler.setOutputCol("features")
    stages += vectorAssembler
    // (2) Identify categorical features using VectorIndexer.
    //     Features with more than maxCategories values will be treated as continuous.
    //     Note: the forest below reads the raw "features" column, so the
    //     "indexedFeatures" output of this stage is currently unused.
    val featuresIndexer = new VectorIndexer()
      .setInputCol("features")
      .setOutputCol("indexedFeatures")
      .setMaxCategories(10)
    stages += featuresIndexer
    // (3) Learn Random Forest.
    val dt = algo match {
      case "classification" =>
        new RandomForestClassifier()
          .setFeaturesCol("features")
          .setLabelCol(labelColName)
          .setMaxDepth(params.maxDepth)
          .setMaxBins(params.maxBins)
          .setMinInstancesPerNode(params.minInstancesPerNode)
          .setMinInfoGain(params.minInfoGain)
          .setCacheNodeIds(params.cacheNodeIds)
          .setCheckpointInterval(params.checkpointInterval)
          .setFeatureSubsetStrategy(params.featureSubsetStrategy)
          .setNumTrees(params.numTrees)
      case "regression" =>
        new RandomForestRegressor()
          .setFeaturesCol("features")
          .setLabelCol(labelColName)
          .setMaxDepth(params.maxDepth)
          .setMaxBins(params.maxBins)
          .setMinInstancesPerNode(params.minInstancesPerNode)
          .setMinInfoGain(params.minInfoGain)
          .setCacheNodeIds(params.cacheNodeIds)
          .setCheckpointInterval(params.checkpointInterval)
          .setFeatureSubsetStrategy(params.featureSubsetStrategy)
          .setNumTrees(params.numTrees)
      case _ => throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
    }
    stages += dt
    val pipeline = new Pipeline().setStages(stages.toArray)
    // Fit the Pipeline.
    val startTime = System.nanoTime()
    val pipelineModel = pipeline.fit(training)
    val elapsedTime = (System.nanoTime() - startTime) / 1e9
    println(s"Training time: $elapsedTime seconds")
    /**
      * Convert the fitted pipeline to PMML and write it to HDFS.
      */
    val modelPmmlPath = params.modelDir
    val pmml = ConverterUtil.toPMML(training.schema, pipelineModel)
    val conf = new Configuration()
    HadoopFileUtil.deleteFile(modelPmmlPath)
    val path = new Path(modelPmmlPath)
    val fs = path.getFileSystem(conf)
    val out = fs.create(path)
    try {
      MetroJAXBUtil.marshalPMML(pmml, out)
    } finally {
      out.close()
    }
    // Echo the PMML document to stdout for inspection.
    JAXBUtil.marshalPMML(pmml, new StreamResult(System.out))
    val predictions = pipelineModel.transform(training)
    // Get the trained Random Forest from the fitted PipelineModel.
    algo match {
      case "classification" =>
        val rfModel = pipelineModel.stages.last.asInstanceOf[RandomForestClassificationModel]
        if (rfModel.totalNumNodes < 30) {
          println(rfModel.toDebugString) // Print full model.
        } else {
          println(rfModel) // Print model summary.
        }
      case "regression" =>
        val rfrModel = pipelineModel.stages.last.asInstanceOf[RandomForestRegressionModel]
        if (rfrModel.totalNumNodes < 30) {
          println(rfrModel.toDebugString) // Print full model.
        } else {
          println(rfrModel) // Print model summary.
        }
      case _ => throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
    }
    // Evaluate model on training, test data.
    algo match {
      case "classification" =>
        println("Training data results:")
        DecisionTreeExample.evaluateClassificationModel(pipelineModel, training, labelColName)
        val evaluator = new MulticlassClassificationEvaluator()
          .setLabelCol("indexedLabel")
          .setPredictionCol("prediction")
          .setMetricName("accuracy")
        val accuracy = evaluator.evaluate(predictions)
        println("Test Error = " + (1.0 - accuracy))
      case "regression" =>
        println("Training data results:")
        DecisionTreeExample.evaluateRegressionModel(pipelineModel, training, labelColName)
      case _ =>
        throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
    }
    predictions.printSchema()
    if (algo == "classification") {
      predictions.select("label", "prediction", "probability").show(10)
    } else {
      predictions.select("label", "prediction").show(10)
    }
    spark.stop()
  }
}
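Since the session is created with .master("local"), the task can be smoke-tested by invoking it directly with the two required positional arguments after any options; the paths below are hypothetical:

RandomForestPMMLTask.main(Array(
  "--algo", "classification",
  "--numTrees", "5",
  "hdfs:///tmp/demo/train_data",     // <input>: labeled examples
  "hdfs:///tmp/demo/model/rf.pmml"   // <modelDir>: PMML output path
))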

Online real-time prediction with jpmml-evaluator. The scoring side must supply the same eight features, in the same order, that the training-side VectorAssembler consumed:

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.InputField;
import org.jpmml.evaluator.ModelEvaluatorFactory;
import org.jpmml.evaluator.ProbabilityDistribution;
import org.jpmml.evaluator.TargetField;
import org.jpmml.model.PMMLUtil;
/**
 * Created by sjmei on 2017/1/19.
 */
public class PredictScore {
    public static void main(String[] args) throws Exception {
        PMML pmml = readPMML(new File("data/pmmlmodel/rf.pmml"));
        ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
//        System.out.println(pmml.getModels().get(0));
        Evaluator evaluator = modelEvaluatorFactory.newModelEvaluator(pmml);
//        ModelEvaluator evaluator = new MiningModelEvaluator(pmml);
        evaluator.verify();
        List<InputField> inputFields = evaluator.getInputFields();
        InputStream is = new FileInputStream(new File("data/train.txt"));
        BufferedReader br = new BufferedReader(new InputStreamReader(is));
        String line;
        int diffDelta = 0;
        int sameDelta = 0;
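        // Each line of data/train.txt is tab-separated: column 14 holds the target flag and
        // column 2 a risk score, from which the expected label is derived for comparison.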
        while((line = br.readLine()) != null) {
            String[] splits = line.split("\t",-1);
            double targetMs = transToDouble(splits[14]);
            double risk_value = transToDouble(splits[2]);
            double label = 0.0;
            if(targetMs==1.0 && risk_value >5.0d){
                label = 1.0;
            }
            LinkedHashMap<FieldName, FieldValue> arguments = readArgumentsFromLine(splits, inputFields);
            Map<FieldName, ?> results = evaluator.evaluate(arguments);
            List<TargetField> targetFields = evaluator.getTargetFields();
            for(TargetField targetField : targetFields){
                FieldName targetFieldName = targetField.getName();
                Object targetFieldValue = results.get(targetFieldName);
                ProbabilityDistribution distribution = (ProbabilityDistribution)targetFieldValue;
                Object result = distribution.getResult();
                if(label == Double.valueOf(result.toString())){
                    sameDelta +=1;
                }else{
                    diffDelta +=1;
                }
            }
        }
        br.close();
        System.out.println("acc count:"+sameDelta);
        System.out.println("error count:"+diffDelta);
        System.out.println("acc rate:"+(sameDelta*1.0d/(sameDelta+diffDelta)));
    }
    /**
     * Read the PMML model from a file.
     * @param file the PMML model file
     * @return the parsed PMML object
     * @throws Exception if the file cannot be read or parsed
     */
    public static PMML readPMML(File file) throws Exception {
        InputStream is = new FileInputStream(file);
        return PMMLUtil.unmarshal(is);
    }
    /**
     * Build the model input fields from one raw tab-separated record. The values must be
     * supplied in the same order as the training-side VectorAssembler columns.
     * @param splits the tab-separated fields of one input line
     * @param inputFields the input fields declared by the PMML model
     * @return the prepared argument map
     */
    public static LinkedHashMap<FieldName, FieldValue> readArgumentsFromLine(String[] splits, List<InputField> inputFields) {
        List<Double> lists = new ArrayList<Double>();
        lists.add(Double.valueOf(splits[3]));
        lists.add(Double.valueOf(splits[4]));
        lists.add(Double.valueOf(splits[5]));
        lists.add(Double.valueOf(splits[7]));
        lists.add(Double.valueOf(splits[8]));
        lists.add(Double.valueOf(splits[9]));
        lists.add(Double.valueOf(splits[10]));
        lists.add(Double.valueOf(splits[11]));
        LinkedHashMap<FieldName, FieldValue> arguments = new LinkedHashMap<FieldName, FieldValue>();
        int i = 0;
        for(InputField inputField : inputFields){
            FieldName inputFieldName = inputField.getName();
            // The raw (ie. user-supplied) value could be any Java primitive value
            Object rawValue = lists.get(i);
            // The raw value is passed through: 1) outlier treatment, 2) missing value treatment, 3) invalid value treatment and 4) type conversion
            FieldValue inputFieldValue = inputField.prepare(rawValue);
            arguments.put(inputFieldName, inputFieldValue);
            i+=1;
        }
        return arguments;
    }
    public static Double transToDouble(String label) {
        try {
            return Double.valueOf(label);
        }catch (Exception e){
            return Double.valueOf(0);
        }
    }
}
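The same evaluator can also score a single in-memory record, which is closer to how an online service would be called. A minimal Scala sketch against the same model path; the feature values below are made up, and the field names are the original training columns:

import java.io.FileInputStream
import scala.collection.JavaConverters._
import org.dmg.pmml.FieldName
import org.jpmml.evaluator.{FieldValue, ModelEvaluatorFactory}
import org.jpmml.model.PMMLUtil

object SingleRecordScore {
  def main(args: Array[String]): Unit = {
    val pmml = PMMLUtil.unmarshal(new FileInputStream("data/pmmlmodel/rf.pmml"))
    val evaluator = ModelEvaluatorFactory.newInstance().newModelEvaluator(pmml)
    evaluator.verify()
    // Hypothetical raw feature values, keyed by the original column names.
    val raw = Map[String, java.lang.Double](
      "degree" -> 3.0, "tcNum" -> 1.0, "pageRank" -> 0.15, "commVertexNum" -> 2.0,
      "normQ" -> 0.5, "gtRate" -> 0.1, "eqRate" -> 0.2, "ltRate" -> 0.7)
    val arguments = new java.util.LinkedHashMap[FieldName, FieldValue]()
    for (field <- evaluator.getInputFields.asScala) {
      arguments.put(field.getName, field.prepare(raw(field.getName.getValue)))
    }
    // The result map contains the predicted label plus class probabilities.
    println(evaluator.evaluate(arguments))
  }
}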