程式人生 > Spark平臺下的組合分類器AdaBoost

Spark平臺下的組合分類器AdaBoost

首先在 GitHub 上找到一個已經實現好的 AdaBoost 套件,可以先用它測試一下能否正常使用。

https://github.com/tizfa/sparkboost

該套件的 Java API 需要 JavaRDD<MultilabelPoint> 格式的資料,而 MLUtils.loadLibSVMFile 讀取得到的是 RDD<LabeledPoint>(可透過 toJavaRDD() 轉為 JavaRDD<LabeledPoint>)。

因此需要在這兩種資料格式之間進行轉換,把 label 與 feature 一一對應起來。

public class ClassifierTask {
public static void main(String[] args) {
SparkConf conf = new SparkConf().setAppName("ClassifierTask").setMaster("local");

                JavaSparkContext sc = new JavaSparkContext(conf);

               // 得到常用的Sparkconf和sc, JavaSparkContext to SparkContext
                SparkContext sc1 = sc.sc();

                String inputFile = "D:\\softs\\spark-1.6.0-bin-hadoop2.6\\data\\mllib\\sample_binary_classification_data.txt";
                JavaRDD<String> StringFile = sc.textFile("D:\\softs\\spark-1.6.0-bin-hadoop2.6\\data\\mllib\\sample_libsvm_data.txt");
         
JavaRDD<LabeledPoint> FileLabeledPoint = MLUtils.loadLibSVMFile(sc1, inputFile).toJavaRDD();

               // from RDD to train model,轉換成multilabelpoint
        JavaRDD<MultilabelPoint> rdd = FileLabeledPoint.map(Row -> {
   int a = (int)Row.label();
   SparseVector b = (SparseVector)Row.features();
   int docID =0;
   int[] labels = {a};
   SparseVector feature = b;
   return new MultilabelPoint(docID, feature, labels);
   });
        //train set is 0.8, test set is 0.2,設定權重
       double[] weights = {0.8,0.2};
       
       JavaRDD<MultilabelPoint>[] data = rdd.randomSplit(weights);
       AdaBoostMHLearner learner = new AdaBoostMHLearner(sc);

               //設定分類器的各項引數
learner.setNumIterations(100);
learner.setNumDocumentsPartitions(2);
learner.setNumFeaturesPartitions(2);
learner.setNumLabelsPartitions(2); 
BoostClassifier classifier = learner.buildModel(data[0]);
   
ClassificationResults  results = classifier.classifyWithResults(sc, data[1], 1);


// Print results in a StringBuilder.
StringBuilder sb = new StringBuilder();
sb.append("**** Effectiveness\n");
sb.append(results.getCt().toString() + "\n");
sb.append("********\n");
for (int i = 0; i < results.getNumDocs(); i++) {
int docID = results.getDocuments()[i];
int[] labels = results.getLabels()[i];
int[] goldLabels = results.getGoldLabels()[i];
sb.append("DocID: " + docID + ", Labels assigned: " + Arrays.toString(labels) + ", Labels scores: " + Arrays.toString(results.getScores()[i]) + ", Gold labels: " +                             Arrays.toString(goldLabels) + "\n");
}
System.out.print(sb);
}
}