Spark 2.x 決策樹 示例程式碼-IRIS資料集
阿新 • 發佈:2019-02-13
資料集下載
程式碼
package Iris;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.Transformer;
import org.apache .spark.ml.classification.DecisionTreeClassificationModel;
import org.apache.spark.ml.classification.DecisionTreeClassifier;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.feature.*;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark .ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField ;
import org.apache.spark.sql.types.StructType;
import scala.Tuple2;
import util.InitSparkUtil;
import java.util.HashMap;
import java.util.Map;
/**
* Created by xy on 2018/4/20.
*/
/**
 * Decision-tree classification demo on the IRIS dataset using the Spark ML
 * Pipeline API: load CSV, assemble feature vectors, stratified-sample a
 * training set, index features/labels, train, predict, and evaluate accuracy.
 */
public class IrisDT {
    /** The three IRIS class labels (dashes normalized to underscores). */
    public static final String[] iris = new String[]{"Iris_setosa", "Iris_versicolor", "Iris_virginica"};

    /**
     * Runs the full train/evaluate pipeline. Reads the dataset from a local
     * path, prints intermediate DataFrames, the test accuracy, and the
     * learned tree structure.
     */
    public static void irisDT() {
        // 1. Build the SparkSession.
        InitSparkUtil initSparkUtil = new InitSparkUtil();
        SparkSession spark = initSparkUtil.getSparkSession("irisDT");

        // 2. Load the raw CSV; every column is read as a string.
        //    NOTE(review): assumes the file has no header row — a header would
        //    make Double.parseDouble below throw; verify the data file.
        Dataset<Row> data = spark.read().csv("E:\\idea工程\\data\\iris.csv");
        data = data.toDF("Sepal_Length", "Sepal_Width", "Petal_Length", "Petal_Width", "Species");

        // 3. Convert each input row to (features: Vector, label: String).
        //    Reads the Row fields directly instead of round-tripping through
        //    Row.toString() + split(",") as the original did — that string
        //    round-trip breaks if any value ever contains a comma or bracket.
        JavaRDD<Row> irisRowRDD = data.toJavaRDD().map(r -> {
            double[] features = new double[r.size() - 1];
            for (int i = 0; i < features.length; i++) {
                features[i] = Double.parseDouble(r.getString(i));
            }
            // Normalize e.g. "Iris-setosa" -> "Iris_setosa" to match `iris`.
            String label = r.getString(r.size() - 1).replace("-", "_");
            return RowFactory.create(Vectors.dense(features), label);
        });

        // 4. Schema for the (features, label) rows built above.
        StructType schema = new StructType(new StructField[]{
                new StructField("features", new VectorUDT(), false, Metadata.empty()),
                new StructField("label", DataTypes.StringType, false, Metadata.empty())});

        // 5. Stratified split: ~80% of each class for training, rest for test.
        //    NOTE(review): subtract() is set-difference — duplicate rows that
        //    appear in both RDDs are removed entirely, so exact duplicates in
        //    the dataset can vanish from the test set. Confirm this is
        //    acceptable, or split via randomSplit/zipWithIndex instead.
        JavaRDD<Row> trainDataRDD = stratifiedSample(irisRowRDD);
        JavaRDD<Row> testDataRDD = irisRowRDD.subtract(trainDataRDD);
        Dataset<Row> trainData = spark.createDataFrame(trainDataRDD, schema);
        Dataset<Row> testData = spark.createDataFrame(testDataRDD, schema);
        Dataset<Row> fullData = trainData.union(testData);
        fullData.cache();
        trainData.show(150);
        testData.show(150);
        fullData.show(2000);

        // 6. Index the feature vector: any feature with <= 4 distinct values
        //    is treated as categorical; the rest stay continuous. The input
        //    column name must match the schema defined above.
        VectorIndexerModel featureIndexer = new VectorIndexer()
                .setInputCol("features")
                .setMaxCategories(4)
                .setOutputCol("indexedFeatures")
                .fit(fullData);
        Dataset<Row> featureIndexData = featureIndexer.transform(fullData);
        featureIndexData.show(200);

        // 7. Index the string label into a numeric label starting at 0,
        //    ordered by descending frequency (most frequent label -> 0).
        StringIndexerModel labelIndexer = new StringIndexer()
                .setInputCol("label")
                .setOutputCol("indexedLabel")
                .fit(fullData);
        Dataset<Row> labelIndexData = labelIndexer.transform(fullData);
        labelIndexData.show(200);

        // 8. Map predicted numeric labels back to the original strings.
        IndexToString labelConverter = new IndexToString()
                .setInputCol("prediction")
                .setOutputCol("predictedLabel")
                .setLabels(labelIndexer.labels());

        // 9. Decision tree: maxDepth caps tree depth; minInfoGain is the
        //    minimum gain required to split; minInstancesPerNode stops
        //    splitting small nodes; impurity selects gini vs entropy.
        //    NOTE(review): depth 20 with minInstancesPerNode=1 will likely
        //    overfit a 150-row dataset — kept as-is to preserve behavior.
        DecisionTreeClassifier dtClassifier = new DecisionTreeClassifier()
                .setLabelCol("indexedLabel")
                .setFeaturesCol("indexedFeatures")
                .setMaxDepth(20)
                .setMinInfoGain(0.00001)
                .setMinInstancesPerNode(1)
                .setImpurity("gini");

        // Assemble the pipeline: fit() learns a PipelineModel from training
        // data; transform() applies it to new data to produce predictions.
        Pipeline pipeline = new Pipeline().setStages(new PipelineStage[]{
                labelIndexer, featureIndexer, dtClassifier, labelConverter});

        // Train on the stratified training split.
        PipelineModel modelClassifier = pipeline.fit(trainData);

        // Predict on the held-out test split.
        Dataset<Row> predictionClassifier = modelClassifier.transform(testData);
        predictionClassifier.select("predictedLabel", "label", "features").show(200);

        // Evaluate multiclass accuracy on the test predictions.
        MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
                .setLabelCol("indexedLabel")
                .setPredictionCol("prediction")
                .setMetricName("accuracy");
        double accuracy = evaluator.evaluate(predictionClassifier);
        System.out.println(accuracy);

        // Dump the learned tree structure (stage 2 of the pipeline is the
        // DecisionTreeClassifier, so stages()[2] is its fitted model).
        Transformer dtModel = modelClassifier.stages()[2];
        DecisionTreeClassificationModel treeClassModel = (DecisionTreeClassificationModel) dtModel;
        String treeModelStruct = treeClassModel.toDebugString();
        System.out.println(treeModelStruct);
        fullData.unpersist();
    }

    /**
     * Draws an exact stratified sample of ~80% per class.
     *
     * @param irisRowRDD rows shaped (features: Vector, label: String); the
     *                   label is read from field index 1
     * @return the sampled training rows
     */
    protected static JavaRDD<Row> stratifiedSample(JavaRDD<Row> irisRowRDD) {
        // Key each row by its class label so sampling can be done per class.
        JavaPairRDD<String, Row> pairRDD = irisRowRDD.mapToPair(x -> new Tuple2<>(x.getString(1), x));
        Map<String, Double> fractions = new HashMap<>();
        for (String label : iris) {
            fractions.put(label, 0.8);
        }
        // sampleByKeyExact guarantees the exact per-key fraction; seed 0
        // keeps the split reproducible across runs.
        return pairRDD.sampleByKeyExact(false, fractions, 0).map(x -> x._2);
    }

    public static void main(String[] args) {
        irisDT();
    }
}
package util;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;
/**
* 初始化spark類
*/
/**
 * Small factory for local-mode Spark entry points (SparkSession and
 * JavaSparkContext) used by the examples in this project.
 */
public class InitSparkUtil {
    private JavaSparkContext sc;

    /**
     * Builds (or reuses, via getOrCreate) a local-mode SparkSession.
     *
     * @param appname the Spark application name
     * @return the SparkSession
     */
    public SparkSession getSparkSession(String appname) {
        SparkConf sparkConf = new SparkConf().setMaster("local");
        return SparkSession.builder()
                .appName(appname)
                .config(sparkConf)
                .getOrCreate();
    }

    /**
     * Creates a local-mode JavaSparkContext and caches it on this instance.
     *
     * @param appname the Spark application name
     * @return the newly created JavaSparkContext
     */
    public JavaSparkContext getSc(String appname) {
        SparkConf sparkConf = new SparkConf()
                .setMaster("local")
                .setAppName(appname);
        this.sc = new JavaSparkContext(sparkConf);
        return this.sc;
    }
}