Spark快速獲得CrossValidator的最佳模型參數
阿新 • • 發佈:2018-10-22
tokenizer ctp best map 一個 pip eval set alua
Spark提供了便利的Pipeline模型,可以輕松的創建自己的學習模型。
但是大部分模型都是需要提供參數的,如果不提供就是默認參數,那麽怎麽選擇參數就是一個比較常見的問題。Spark提供在org.apache.spark.ml.tuning包下提供了模型選擇器,可以替換參數然後比較模型輸出。
目前有CrossValidator和TrainValidationSplit兩種,比如一個文本情感預測模型。
Pipeline只有三步,第一步切詞,第二步HashingTF,第三步NB分類
Pipeline pipeline = new Pipeline()
.setStages( new PipelineStage[]{tokenizer, hashingTF, naiveBayes});
ParamMap[] paramMaps = new ParamGridBuilder()
.addGrid(hashingTF.numFeatures(), new int[]{10000, 100000, 500000, 1000000})
.build();
CrossValidator cv = new CrossValidator()
.setEstimator(pipeline)
.setEvaluator( new BinaryClassificationEvaluator())
.setEstimatorParamMaps(paramMaps);
其中HashingTF的參數選擇非常重要,我們這裏就隨便嘗試幾種,然後放在CrossValidator中去。
最後我們會獲得一個CrossValidatorModel類,這裏有兩種選擇。
第一種是自己手動獲取其中的參數,因為bestModel的參數就是我們最後選擇的參數
Pipeline bestPipeline = (Pipeline) model.bestModel().parent();
PipelineStage stage = bestPipeline.getStages()[1];
stage.extractParamMap().get(stage.getParam("numFeatures"));
這種方法可以獲得值,但是需要根據你模型情況修改獲取的位置。
如果你只是想知道最佳參數是多少,並不是需要在上下文中使用,那還有一個更簡單的方法。
修改log4j的配置,添加
log4j.logger.org.apache.spark.ml.tuning.TrainValidationSplit=INFO
log4j.logger.org.apache.spark.ml.tuning.CrossValidator=INFO
效果如下:
Spark快速獲得CrossValidator的最佳模型參數