
Implementing a Decision Tree in Spark

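We use the Iris dataset as our example. Iris is built from measurements of iris flowers: it contains 150 samples in 3 classes of 50 samples each, and every sample has 4 attributes. It is one of the most commonly used test and training sets in data mining and classification. Decision trees can be used for both classification and regression, and the code below covers each in turn.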

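  1. Import the required packages:
from pyspark.ml.linalg import Vector, Vectors
from pyspark.sql import Row, SparkSession
from pyspark.ml import Pipeline
from pyspark.ml.feature import IndexToString, StringIndexer, VectorIndexer
# The pyspark shell already defines `spark`; in a standalone script, create (or reuse) a session
spark = SparkSession.builder.appName("iris-decision-tree").getOrCreate()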
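  1. Read the data and take a quick look:
    Read the text file. The first map splits each line on ","; in our dataset every line splits into 5 fields, where the first 4 are the four features of the flower and the last is its class. We store the features in a Vector, build an RDD of Rows following the Iris schema (features, label), and convert it to a DataFrame. We then register the data as a temporary view named iris, which lets us query it with SQL. After selecting the data we need, we print the result to inspect it.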
def f(x):
    # x is one line split on ",": four numeric features followed by the class name
    rel = {}
    rel['features'] = Vectors.dense(float(x[0]), float(x[1]), float(x[2]), float(x[3]))
    rel['label'] = str(x[4])
    return rel

data = (spark.sparkContext.textFile("file:///usr/local/spark/iris.txt")
        .filter(lambda line: line.strip() != "")   # skip blank lines, which would make float() fail
        .map(lambda line: line.split(','))         # split each line into its 5 fields
        .map(lambda p: Row(**f(p)))                # one Row(features=..., label=...) per line
        .toDF())
data.createOrReplaceTempView("iris")
df = spark.sql("select * from iris")
# Format each row as "label:features" and bring the results back to the driver
rel = df.rdd.map(lambda t: str(t.label) + ":" + str(t.features)).collect()
for item in rel:
    print(item)
Iris-setosa:[5.1,3.5,1.4,0.2]
Iris-setosa:[4.9,3.0,1.4,0.2]
Iris-setosa:[4.7,3.2,1.3,0.2]
Iris-setosa:[4.6,3.1,1.5,0.2]
Iris-setosa:[5.0,3.6,1.4,0.2]
Iris-setosa:[5.4,3.9,1.7,0.4]
Iris-setosa:[4.6,3.4,1.4,0.3]
.....
Iris-versicolor:[5.7,2.8,4.1,1.3]
Iris-virginica:[6.3,3.3,6.0,2.5]
Iris-virginica:[5.8,2.7,5.1,1.9]
Iris-virginica:[7.1,3.0,5.9,2.1]
Iris-virginica:[6.3,2.9,5.6,1.8]
Iris-virginica:[6.5,3.0,5.8,2.2]
Iris-virginica:[7.6,3.0,6.6,2.1]
Iris-virginica:[4.9,2.5,4.5,1.7]
Iris-virginica:[7.3,2.9,6.3,1.8]
Iris-virginica:[6.7,2.5,5.8,1.8]
Iris-virginica:[7.2,3.6,6.1,2.5]
Iris-virginica:[6.5,3.2,5.1,2.0]
Iris-virginica:[6.4,2.7,5.3,1.9]
Iris-virginica:[6.8,3.0,5.5,2.1]
Iris-virginica:[5.7,2.5,5.0,2.0]
Iris-virginica:[5.8,2.8,5.1,2.4]
Iris-virginica:[6.4,3.2,5.3,2.3]
Iris-virginica:[6.5,3.0,5.5,1.8]
Iris-virginica:[7.7,3.8,6.7,2.2]
Iris-virginica:[7.7,2.6,6.9,2.3]
Iris-virginica:[6.0,2.2,5.0,1.5]
Iris-virginica:[6.9,3.2,5.7,2.3]
Iris-virginica:[5.6,2.8,4.9,2.0]
Iris-virginica:[7.7,2.8,6.7,2.0]
... ...
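To check what `f` and the chained `map` calls produce without starting Spark, here is a minimal pure-Python sketch. It assumes iris.txt uses the layout described above (four numeric fields, then the class name); `parse_line`, `parse_text` and `format_record` are hypothetical helpers that mirror `f` and the `label:features` strings printed above, with plain lists standing in for Spark `Vectors`. The three sample lines are taken from the printed output.

```python
from collections import Counter


def parse_line(line):
    """Mirror of f(): split one CSV line into a features list and a label string."""
    parts = line.split(",")
    return {
        "features": [float(v) for v in parts[:4]],  # the four numeric measurements
        "label": parts[4],                            # class name, e.g. "Iris-setosa"
    }


def parse_text(text):
    """Parse a whole file's contents, skipping blank lines."""
    return [parse_line(line) for line in text.splitlines() if line.strip()]


def format_record(rec):
    """Build the same "label:[f1,f2,f3,f4]" string the Spark example prints."""
    return rec["label"] + ":[" + ",".join(str(v) for v in rec["features"]) + "]"


# Three lines in iris.txt format, taken from the output printed above
sample = """5.1,3.5,1.4,0.2,Iris-setosa
5.7,2.8,4.1,1.3,Iris-versicolor
6.3,3.3,6.0,2.5,Iris-virginica
"""
records = parse_text(sample)
for rec in records:
    print(format_record(rec))  # e.g. Iris-setosa:[5.1,3.5,1.4,0.2]
print(Counter(rec["label"] for rec in records))  # rows per class
```

Run on the full 150-line file, the `Counter` should report 50 rows for each of the three classes.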
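  1. Further process the features and labels, and split the data:
# Index the label column and the feature column separately, writing the results to new columns.
labelIndexer = StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(df)

# Features with at most 4 distinct values are treated as categorical; for iris, all four stay continuous.
featureIndexer = VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(4).fit(df)

# Set up a labelConverter that converts the predicted class indices back to string labels.
labelConverter = IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels)
# Next, randomly split the dataset into a training set (about 70%) and a test set (about 30%).
trainingData, testData = data.randomSplit([0.7, 0.3])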
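StringIndexer, IndexToString and randomSplit can each be illustrated without Spark. With its default ordering, StringIndexer gives index 0.0 to the most frequent label, IndexToString reverses that mapping, and randomSplit places each row independently, so the 70/30 proportions are only approximate. The pure-Python sketch below is an illustration under assumptions, not Spark's implementation: `string_index`, `index_to_string` and `random_split` are hypothetical helpers, frequency ties are broken alphabetically here (Spark's tie-breaking may differ across versions), and Python's `random` module stands in for Spark's sampler. Since every iris class has 50 rows, the real index order comes down to tie-breaking, so `labelIndexer.labels` is the place to check which index stands for which class.

```python
import random
from collections import Counter


def string_index(labels):
    """Ordered label list, most frequent first; ties broken alphabetically (an assumption)."""
    counts = Counter(labels)
    return sorted(counts, key=lambda lab: (-counts[lab], lab))


def index_to_string(index, ordered_labels):
    """Like IndexToString: map a numeric index back to its label."""
    return ordered_labels[int(index)]


def random_split(rows, weights, seed=None):
    """Give each row one uniform draw and place it by cumulative normalized weight."""
    rng = random.Random(seed)
    total = float(sum(weights))
    bounds, acc = [], 0.0
    for w in weights:
        acc += w / total
        bounds.append(acc)
    splits = [[] for _ in weights]
    for row in rows:
        u = rng.random()
        for i, bound in enumerate(bounds):
            # the last split also catches any floating-point shortfall in the final bound
            if u < bound or i == len(bounds) - 1:
                splits[i].append(row)
                break
    return splits


# Toy labels with unequal counts (made up for illustration)
labels = ["Iris-virginica", "Iris-setosa", "Iris-virginica",
          "Iris-versicolor", "Iris-virginica", "Iris-setosa"]
ordered = string_index(labels)
print(ordered)  # ['Iris-virginica', 'Iris-setosa', 'Iris-versicolor']
print([float(ordered.index(lab)) for lab in labels])  # the "indexedLabel" column
print(index_to_string(2.0, ordered))  # Iris-versicolor

train, test = random_split(list(range(150)), [0.7, 0.3], seed=42)
print(len(train), len(test))  # close to 105 and 45, but not necessarily exact
```

Passing a seed (Spark's `randomSplit` also accepts one) makes the split repeatable; without it, each run produces a different split.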
  1. Build the decision tree classification model:
# Import the required packages
from pyspark.ml.classification import DecisionTreeClassificationModel, DecisionTreeClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
# Create the decision tree classifier. Its parameters can be set with setter methods or with a ParamMap (see the Spark MLlib documentation for details); explainParams() lists every parameter that can be set.
dtClassifier = DecisionTreeClassifier().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures")
# Chain the stages into a Pipeline
pipelinedClassifier = Pipeline().setStages([labelIndexer, featureIndexer, dtClassifier, labelConverter])
# Train the pipeline (including the decision tree) on the training set
modelClassifier = pipelinedClassifier.fit(trainingData)
# Make predictions on the test set
predictionsClassifier = modelClassifier.transform(testData)
# Show some of the predictions
predictionsClassifier.select("predictedLabel", "label", "features").show(20)
+---------------+---------------+-----------------+
| predictedLabel|          label|         features|
+---------------+---------------+-----------------+
|    Iris-setosa|    Iris-setosa|[4.3,3.0,1.1,0.1]|
|    Iris-setosa|    Iris-setosa|[4.6,3.4,1.4,0.3]|
|    Iris-setosa|    Iris-setosa|[4.6,3.6,1.0,0.2]|
|    Iris-setosa|    Iris-setosa|[4.8,3.0,1.4,0.1]|
|    Iris-setosa|    Iris-setosa|[4.8,3.1,1.6,0.2]|
|    Iris-setosa|    Iris-setosa|[4.8,3.4,1.6,0.2]|
|    Iris-setosa|    Iris-setosa|[4.9,3.0,1.4,0.2]|
|    Iris-setosa|    Iris-setosa|[4.9,3.1,1.5,0.1]|
|    Iris-setosa|    Iris-setosa|[5.0,3.5,1.3,0.3]|
|    Iris-setosa|    Iris-setosa|[5.1,3.3,1.7,0.5]|
|    Iris-setosa|    Iris-setosa|[5.1,3.4,1.5,0.2]|
|    Iris-setosa|    Iris-setosa|[5.1,3.7,1.5,0.4]|
|    Iris-setosa|    Iris-setosa|[5.1,3.8,1.9,0.4]|
|Iris-versicolor|Iris-versicolor|[5.2,2.7,3.9,1.4]|
|    Iris-setosa|    Iris-setosa|[5.4,3.9,1.3,0.4]|
|Iris-versicolor|Iris-versicolor|[5.7,2.8,4.5,1.3]|
|Iris-versicolor|Iris-versicolor|[5.8,2.7,4.1,1.0]|
|    Iris-setosa|    Iris-setosa|[5.8,4.0,1.2,0.2]|
| Iris-virginica|Iris-versicolor|[5.9,3.2,4.8,1.8]|
|Iris-versicolor|Iris-versicolor|[6.1,2.9,4.7,1.4]|
+---------------+---------------+-----------------+
only showing top 20 rows
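The "accuracy" metric that MulticlassClassificationEvaluator computes in the next step is the fraction of rows whose prediction equals the label, and the printed test error is 1 - accuracy. As a hand check, the plain-Python sketch below applies that formula to the 20 rows displayed above. Comparing the label strings is equivalent to comparing indexedLabel with prediction, because the index mapping is one-to-one. The result covers only these 20 rows, not the whole test set.

```python
# (predictedLabel, label) pairs transcribed from the 20 rows shown above, in order
shown = (
    [("Iris-setosa", "Iris-setosa")] * 13
    + [("Iris-versicolor", "Iris-versicolor")]
    + [("Iris-setosa", "Iris-setosa")]
    + [("Iris-versicolor", "Iris-versicolor")] * 2
    + [("Iris-setosa", "Iris-setosa")]
    + [("Iris-virginica", "Iris-versicolor")]  # the one misclassified row, [5.9,3.2,4.8,1.8]
    + [("Iris-versicolor", "Iris-versicolor")]
)


def accuracy(pairs):
    """Fraction of (prediction, label) pairs where the prediction is correct."""
    correct = sum(1 for pred, label in pairs if pred == label)
    return correct / len(pairs)


acc = accuracy(shown)
print(len(shown), acc)          # 20 0.95
print(round(1.0 - acc, 4))      # 0.05, the error rate on these rows only
```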
 
  1. Evaluate the decision tree classification model:
evaluatorClassifier = MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction").setMetricName("accuracy")
 
accuracy = evaluatorClassifier.evaluate(predictionsClassifier)
 
print("Test Error = " + str(1.0 - accuracy))
Test Error = 0.05882352941176472
 
treeModelClassifier = modelClassifier.stages[2]
 
print("Learned classification tree model:\n" + str(treeModelClassifier.toDebugString))
Learned classification tree model:
DecisionTreeClassificationModel (uid=DecisionTreeClassifier_4e57b26beacfd363271a) of depth 3 with 7 nodes
  If (feature 2 <= 1.9)
   Predict: 2.0
  Else (feature 2 > 1.9)
   If (feature 3 <= 1.6)
    If (feature 2 <= 4.9)
     Predict: 0.0
    Else (feature 2 > 4.9)
     Predict: 1.0
   Else (feature 3 > 1.6)
    Predict: 1.0
 

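From the results above, the model's prediction accuracy on the test set is about 0.94 (1 - 0.0588), and we can also see the structure of the learned decision tree. Because randomSplit is called without a seed, the exact numbers will differ from run to run.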
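To read the printed tree: assuming iris.txt follows the standard UCI column order, feature 0 is sepal length, 1 sepal width, 2 petal length and 3 petal width (all four stay continuous, so VectorIndexer leaves them unchanged), and each `Predict` value is a label index from labelIndexer. The sketch below transcribes the tree into plain Python. The mapping `LABELS` is an assumption: the real one comes from `labelIndexer.labels`, and the order used here (0 = versicolor, 1 = virginica, 2 = setosa) is inferred because it reproduces all 20 predictions in the table above.

```python
# Assumed index -> class name mapping; check labelIndexer.labels for the real one
LABELS = ["Iris-versicolor", "Iris-virginica", "Iris-setosa"]


def predict_index(features):
    """Hand transcription of the depth-3, 7-node tree printed by toDebugString."""
    petal_length, petal_width = features[2], features[3]
    if petal_length <= 1.9:        # If (feature 2 <= 1.9)
        return 2.0
    if petal_width <= 1.6:         # Else: If (feature 3 <= 1.6)
        if petal_length <= 4.9:    # If (feature 2 <= 4.9)
            return 0.0
        return 1.0                 # Else (feature 2 > 4.9)
    return 1.0                     # Else (feature 3 > 1.6)


def predict_label(features):
    """Map the predicted index back to a name, as labelConverter does."""
    return LABELS[int(predict_index(features))]


for row in ([4.3, 3.0, 1.1, 0.1], [5.2, 2.7, 3.9, 1.4], [5.9, 3.2, 4.8, 1.8]):
    print(row, predict_label(row))
```

The last row is the misclassified versicolor from the table: its petal width of 1.8 exceeds 1.6, so the tree sends it to a virginica leaf.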

  1. Build the decision tree regression model:
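DecisionTreeRegressor grows the tree in the same way, but it scores candidate splits by variance (the only impurity it supports) rather than the classifier's default Gini impurity, and each leaf predicts the mean label of its training rows. Since the label here is the class index, a prediction is a whole-number index only when its leaf is pure, and only then does labelConverter map it back to a class name cleanly. The plain-Python sketch below computes the variance reduction of a candidate threshold and a leaf's mean prediction; `variance`, `variance_gain` and `leaf_prediction` are hypothetical helpers rather than Spark APIs, and the toy data is made up.

```python
def variance(values):
    """Population variance of the labels, the impurity used by regression trees."""
    if not values:
        return 0.0
    mean = sum(values) / len(values)
    return sum((v - mean) ** 2 for v in values) / len(values)


def variance_gain(xs, ys, threshold):
    """Impurity decrease from splitting rows into x <= threshold and x > threshold."""
    left = [y for x, y in zip(xs, ys) if x <= threshold]
    right = [y for x, y in zip(xs, ys) if x > threshold]
    n = len(ys)
    weighted = len(left) / n * variance(left) + len(right) / n * variance(right)
    return variance(ys) - weighted


def leaf_prediction(values):
    """A regression leaf predicts the mean label of its training rows."""
    return sum(values) / len(values)


# Toy data (made up): petal lengths with class indices used as numeric labels
xs = [1.4, 1.5, 4.5, 4.7, 5.8, 6.1]
ys = [2.0, 2.0, 0.0, 0.0, 1.0, 1.0]
for t in (1.9, 4.9):
    print(t, round(variance_gain(xs, ys, t), 4))  # 1.9 -> 0.5, 4.9 -> 0.0
print(leaf_prediction([0.0, 0.0, 1.0, 1.0]))      # 0.5: a mixed leaf, not a whole-number index
```

On this toy data the 4.9 threshold shows no gain even though it cleanly separates the virginica rows (index 1.0) from the rest, whereas Gini impurity would reward it: averaging class indices treats them as ordered numbers.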
# Import the required packages
from pyspark.ml.regression import DecisionTreeRegressionModel, DecisionTreeRegressor
from pyspark.ml.evaluation import RegressionEvaluator
# Create the decision tree regressor; the indexed label is used as a numeric target
dtRegressor = DecisionTreeRegressor().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures")
# Chain the stages into a Pipeline
pipelineRegressor = Pipeline().setStages([labelIndexer, featureIndexer, dtRegressor, labelConverter])
# Train the pipeline on the training set
modelRegressor = pipelineRegressor.fit(trainingData)
# Make predictions on the test set
predictionsRegressor = modelRegressor.transform(testData)
# Show some of the predictions
predictionsRegressor.select("predictedLabel", "label", "features").show(20)
 
+---------------+---------------+-----------------+
| predictedLabel|          label|         features|
+---------------+---------------+-----------------+
|    Iris-setosa|    Iris-setosa|[4.3,3.0,1.1,0.1]|
|    Iris-setosa|    Iris-setosa|[4.6,3.4,1.4,0.3]|
|    Iris-setosa|    Iris-setosa|[4.6,3.6,1.0,0.2]|
|    Iris-setosa|    Iris-setosa|[4.8,3.0,1.4,0.1]|
|    Iris-setosa|    Iris-setosa|[4.8,3.1,1.6,0.2]|
|    Iris-setosa|    Iris-setosa|[4.8,3.4,1.6,0.2]|
|    Iris-setosa|    Iris-setosa|[4.9,3.0,1.4,0.2]|
|    Iris-setosa|    Iris-setosa|[4.9,3.1,1.5,0.1]|
|    Iris-setosa|    Iris-setosa|[5.0,3.5,1.3,0.3]|
|    Iris-setosa|    Iris-setosa|[5.1,3.3,1.7,0.5]|
|    Iris-setosa|    Iris-setosa|[5.1,3.4,1.5,0.2]|
|    Iris-setosa|    Iris-setosa|[5.1,3.7,1.5,0.4]|
|    Iris-setosa|    Iris-setosa|[5.1,3.8,1.9,0.4]|
|Iris-versicolor|Iris-versicolor|[5.2,2.7,3.9,1.4]|
|    Iris-setosa|    Iris-setosa|[5.4,3.9<