1. 程式人生 > >Spark快速獲得CrossValidator的最佳模型參數

Spark快速獲得CrossValidator的最佳模型參數

tokenizer ctp best map 一個 pip eval set alua

Spark提供了便利的Pipeline模型,可以輕松的創建自己的學習模型。

但是大部分模型都是需要提供參數的,如果不提供就是默認參數,那麽怎麽選擇參數就是一個比較常見的問題。Spark提供在org.apache.spark.ml.tuning包下提供了模型選擇器,可以替換參數然後比較模型輸出。

目前有CrossValidator和TrainValidationSplit兩種,比如一個文本情感預測模型。

Pipeline只有三步,第一步切詞,第二步HashingTF,第三步NB分類

Pipeline pipeline = new Pipeline()
                .setStages(
new PipelineStage[]{tokenizer, hashingTF, naiveBayes}); ParamMap[] paramMaps = new ParamGridBuilder() .addGrid(hashingTF.numFeatures(), new int[]{10000, 100000, 500000, 1000000}) .build(); CrossValidator cv = new CrossValidator() .setEstimator(pipeline) .setEvaluator(
new BinaryClassificationEvaluator()) .setEstimatorParamMaps(paramMaps);

其中HashingTF的參數選擇非常重要,我們這裏就隨便嘗試幾種,然後放在CrossValidator中去。

最後我們會獲得一個CrossValidatorModel類,這裏有兩種選擇。

第一種是自己手動獲取其中的參數,因為bestModel的參數就是我們最後選擇的參數

Pipeline bestPipeline = (Pipeline) model.bestModel().parent();
PipelineStage stage 
= bestPipeline.getStages()[1]; stage.extractParamMap().get(stage.getParam("numFeatures"));

這種方法可以獲得值,但是需要根據你模型情況修改獲取的位置。

如果你只是想知道最佳參數是多少,並不是需要在上下文中使用,那還有一個更簡單的方法。

修改log4j的配置,添加

log4j.logger.org.apache.spark.ml.tuning.TrainValidationSplit=INFO
log4j.logger.org.apache.spark.ml.tuning.CrossValidator=INFO

效果如下:

技術分享圖片

Spark快速獲得CrossValidator的最佳模型參數