
Implementing WordCount with Spark UDFs

       User-defined functions (UDFs) are a key feature of most SQL environments, used to extend a system's built-in functionality. UDFs let developers enable new functionality in a higher-level language such as SQL while abstracting away the lower-level language implementation. Apache Spark is no exception, and it offers a variety of options for integrating UDFs into Spark SQL workflows.
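In practice there are two common ways to use a UDF in Spark: register it for SQL with spark.udf.register (as done in the example below), or wrap an ordinary function with org.apache.spark.sql.functions.udf and apply it through the DataFrame API. A minimal sketch of the DataFrame-API style (the DataFrame df and its word column are hypothetical, not part of this article's example):

import org.apache.spark.sql.functions.{col, udf}

// Wrap an ordinary Scala function as a column expression
val toUpper = udf((s: String) => s.toUpperCase)
// df.select(col("word"), toUpper(col("word")).as("upper_word"))  // hypothetical DataFrame df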

      This article implements the WordCount example with a custom UDF and UDAF:

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}

object UDF {

  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder().appName("UDF").master("local[2]").getOrCreate()

    val sc: SparkContext = spark.sparkContext
    sc.setLogLevel("WARN")

    val bigData = Array("Spark", "Spark", "Hadoop", "Spark", "Hadoop", "Spark", "Spark", "Hadoop", "Spark", "Hadoop")

    val bigDataRDD: RDD[String] = sc.parallelize(bigData)

    val bigDataRDDRow: RDD[Row] = bigDataRDD.map(item => Row(item))
    val structType: StructType = StructType(Array(
      StructField("word", StringType, true)
    ))
    val bigDataDF: DataFrame = spark.createDataFrame(bigDataRDDRow, structType)

    bigDataDF.createOrReplaceTempView("bigDataTable")

    spark.udf.register("computeLength",(input:String) => input.length)
    //直接在SQL語句中使用UDF,就像使用SQL內建函式一樣
    spark.sql("select word,computeLength(word) as length from bigDataTable").show()

    spark.udf.register("wordCount", new MyUDAF)
    spark.sql("select word,computeLength(word) as length, wordCount(word) as count from bigDataTable group by word").show()

    spark.stop()  // also stops the underlying SparkContext

  }

}
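With this input the second query returns one row per distinct word: Spark appears 6 times and Hadoop 4 times. For comparison, the built-in count would produce the same totals; the custom UDAF here is purely for demonstration:

spark.sql("select word, count(word) as count from bigDataTable group by word").show()

The wordCount UDAF registered above is implemented as follows: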
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._


class MyUDAF extends UserDefinedAggregateFunction {
  // Data type of the input column passed to the aggregation
  override def inputSchema: StructType = StructType(Array(StructField("input", StringType, true)))

  // Schema of the intermediate buffer used while aggregating
  override def bufferSchema: StructType = StructType(Array(StructField("count", IntegerType, true)))

  // Data type of the final result
  override def dataType: DataType = IntegerType

  // The function is deterministic: the same input always yields the same result
  override def deterministic: Boolean = true

  // Initial value of each group's buffer before aggregation starts
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0
  }

  // How a group's buffer is updated each time a new input row arrives
  // This is the local (per-partition) aggregation, analogous to a Combiner in the Hadoop MapReduce model
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getAs[Int](0) + 1
  }

  // After the local reduce on each node completes, merge the partial buffers into a global result
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Int](0) + buffer2.getAs[Int](0)
  }

  // Return the final count for the group
  override def evaluate(buffer: Row): Any = buffer.getAs[Int](0)
}
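Note that from Spark 3.0 onwards, UserDefinedAggregateFunction is deprecated in favour of the typed Aggregator API registered via org.apache.spark.sql.functions.udaf. A minimal sketch of the same word count in that style (assuming Spark 3.0 or later; WordCountAgg is a name chosen here for illustration):

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.udaf

// Word count as a typed Aggregator: the input is a String, the buffer and output are Int
object WordCountAgg extends Aggregator[String, Int, Int] {
  override def zero: Int = 0                                        // initial buffer value
  override def reduce(buffer: Int, word: String): Int = buffer + 1  // per-row update
  override def merge(b1: Int, b2: Int): Int = b1 + b2               // combine partial buffers
  override def finish(reduction: Int): Int = reduction              // final result for a group
  override def bufferEncoder: Encoder[Int] = Encoders.scalaInt
  override def outputEncoder: Encoder[Int] = Encoders.scalaInt
}

// Registered for SQL use in the same way as the UDAF above:
// spark.udf.register("wordCount", udaf(WordCountAgg))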