Spark平臺下的組合分類器AdaBoost
首先在GitHub上找到了一個現成的AdaBoost實作套件(sparkboost),可以先用來測試看看能否正常使用。
https://github.com/tizfa/sparkboost
sparkboost的Java介面需要的輸入是JavaRDD<MultilabelPoint>資料格式,而MLUtils.loadLibSVMFile讀取得到的是RDD<LabeledPoint>,可以透過toJavaRDD()轉為JavaRDD<LabeledPoint>。
因此需要在這兩種資料格式之間進行轉換,把每筆資料的label和feature對應起來。
/**
 * Trains a sparkboost AdaBoost.MH classifier on a LIBSVM-format dataset and prints the
 * evaluation results for a randomly held-out test split.
 *
 * <p>Usage: {@code ClassifierTask [inputFile]}. When no argument is given, the Spark 1.6.0
 * sample binary classification file at {@link #DEFAULT_INPUT_FILE} is used.
 */
public class ClassifierTask {

    /** Dataset read when no input path is supplied on the command line. */
    private static final String DEFAULT_INPUT_FILE =
            "D:\\softs\\spark-1.6.0-bin-hadoop2.6\\data\\mllib\\sample_binary_classification_data.txt";

    public static void main(String[] args) {
        String inputFile = (args != null && args.length > 0) ? args[0] : DEFAULT_INPUT_FILE;

        SparkConf conf = new SparkConf().setAppName("ClassifierTask").setMaster("local");
        JavaSparkContext sc = new JavaSparkContext(conf);
        try {
            // MLUtils works on the Scala SparkContext that backs the Java wrapper.
            JavaRDD<LabeledPoint> labeledPoints =
                    MLUtils.loadLibSVMFile(sc.sc(), inputFile).toJavaRDD();
            JavaRDD<MultilabelPoint> rdd = toMultilabelPoints(labeledPoints);

            // 80% of the documents are used for training, 20% for testing.
            double[] weights = {0.8, 0.2};
            JavaRDD<MultilabelPoint>[] data = rdd.randomSplit(weights);

            AdaBoostMHLearner learner = new AdaBoostMHLearner(sc);
            // Classifier hyper-parameters.
            learner.setNumIterations(100);
            learner.setNumDocumentsPartitions(2);
            learner.setNumFeaturesPartitions(2);
            learner.setNumLabelsPartitions(2);

            BoostClassifier classifier = learner.buildModel(data[0]);
            // NOTE(review): the third argument is passed through unchanged from the original
            // code; confirm its meaning (likely a parallelism/partition count) in sparkboost.
            ClassificationResults results = classifier.classifyWithResults(sc, data[1], 1);

            System.out.print(formatResults(results));
        } finally {
            // Release the Spark context even if training or classification fails.
            sc.stop();
        }
    }

    /**
     * Converts single-label LabeledPoints into the MultilabelPoint format sparkboost expects.
     *
     * <p>Each point gets a unique document ID derived from its position in the RDD, so the
     * per-document results can be told apart. The original code gave every document ID 0.
     * The label is truncated to an {@code int} and becomes the only entry of the label array.
     *
     * @param points labeled points, presumably loaded from a LIBSVM file (features are sparse)
     * @return the converted points, in the same order
     */
    private static JavaRDD<MultilabelPoint> toMultilabelPoints(JavaRDD<LabeledPoint> points) {
        // zipWithIndex assigns consecutive 0-based indices in RDD order.
        return points.zipWithIndex().map(pair -> {
            LabeledPoint point = pair._1();
            // toIntExact fails loudly instead of silently wrapping past Integer.MAX_VALUE docs.
            int docID = Math.toIntExact(pair._2());
            int[] labels = {(int) point.label()};
            // loadLibSVMFile produces sparse feature vectors.
            SparseVector features = (SparseVector) point.features();
            return new MultilabelPoint(docID, features, labels);
        });
    }

    /**
     * Renders the contingency table and the per-document predictions as human-readable text.
     *
     * @param results classification results returned by the trained classifier
     * @return the formatted report, one line per classified document
     */
    private static String formatResults(ClassificationResults results) {
        StringBuilder sb = new StringBuilder();
        sb.append("**** Effectiveness\n");
        sb.append(results.getCt().toString() + "\n");
        sb.append("********\n");
        for (int i = 0; i < results.getNumDocs(); i++) {
            int docID = results.getDocuments()[i];
            int[] labels = results.getLabels()[i];
            int[] goldLabels = results.getGoldLabels()[i];
            sb.append("DocID: " + docID
                    + ", Labels assigned: " + Arrays.toString(labels)
                    + ", Labels scores: " + Arrays.toString(results.getScores()[i])
                    + ", Gold labels: " + Arrays.toString(goldLabels)
                    + "\n");
        }
        return sb.toString();
    }
}