Spark-MLlib的快速使用之三(樸素貝葉斯分類)
(1)描述資訊
隨機森林演算法是機器學習、計算機視覺等領域內應用極為廣泛的一個演算法,它不僅可以用來做分類,也可用來做迴歸即預測,隨機森林機由多個決策樹構成,相比於單個決策樹演算法,它分類、預測效果更好,不容易出現過度擬合的情況。
隨機森林演算法基於決策樹,在正式講解隨機森林演算法之前,先來介紹決策樹的原理。決策樹是資料探勘與機器學習領域中一種非常重要的分類器,演算法通過訓練資料來構建一棵用於分類的樹,從而對未知資料進行高效分類。舉個相親的例子來說明什麼是決策樹、如何構建一個決策樹及如何利用決策樹進行分類,某相親網站通過調查相親歷史資料發現,女孩在實際相親時有如下表現:
(2)測試資料
1 125:145 126:255 127:211 128:31 152:32 153:237 154:253 155:252 156:71 180:11 181:175 182:253 183:252 184:71 209:144 210:253 211:252 212:71 236:16 237:191 238:253 239:252 240:71 264:26 265:221 266:253 267:252 268:124 269:31 293:125 294:253 295:252 296:252 297:108 322:253 323:252 324:252 325:108 350:255 351:253 352:253 353:108 378:253 379:252 380:252 381:108 406:253 407:252 408:252 409:108 434:253 435:252 436:252 437:108 462:255 463:253 464:253 465:170 490:253 491:252 492:252 493:252 494:42 518:149 519:252 520:252 521:252 522:144 546:109 547:252 548:252 549:252 550:144 575:218 576:253 577:253 578:255 579:35 603:175 604:252 605:252 606:253 607:35 631:73 632:252 633:252 634:253 635:35 659:31 660:211 661:252 662:253 663:35
(3)測試程式碼
public static void main(String[] args) {
// $example on$
SparkConf sparkConf = new SparkConf().setAppName("JavaRandomForestClassificationExample").setMaster("local");
JavaSparkContext jsc = new JavaSparkContext(sparkConf);
// Load and parse the data file.
String datapath = "sample_libsvm_data.txt";
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();
// Split the data into training and test sets (30% held out for testing)
JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});
JavaRDD<LabeledPoint> trainingData = splits[0];
JavaRDD<LabeledPoint> testData = splits[1];
// Train a RandomForest model.
// Empty categoricalFeaturesInfo indicates all features are continuous.
Integer numClasses = 2;
HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
Integer numTrees = 3; // Use more in practice.
String featureSubsetStrategy = "auto"; // Let the algorithm choose.
String impurity = "gini";
Integer maxDepth = 5;
Integer maxBins = 32;
Integer seed = 12345;
final RandomForestModel model = RandomForest.trainClassifier(trainingData, numClasses,
categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins,
seed);
// Evaluate model on test instances and compute test error
JavaPairRDD<Double, Double> predictionAndLabel =
testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
@Override
public Tuple2<Double, Double> call(LabeledPoint p) {
return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
}
});
System.out.println("----------->" + predictionAndLabel.take(10));
Double testErr =
1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
@Override
public Boolean call(Tuple2<Double, Double> pl) {
return !pl._1().equals(pl._2());
}
}).count() / testData.count();
System.out.println("Test Error: " + testErr);
System.out.println("Learned classification forest model:\n" + model.toDebugString());
// Save and load model
model.save(jsc.sc(), "target/tmp/myRandomForestClassificationModel");
RandomForestModel sameModel = RandomForestModel.load(jsc.sc(),
"target/tmp/myRandomForestClassificationModel");
// $example off$
}
(4)測試結果
[(1.0,1.0), (1.0,1.0), (0.0,0.0), (1.0,1.0), (0.0,0.0), (1.0,1.0), (0.0,0.0), (1.0,1.0), (1.0,1.0), (0.0,0.0)]