1. 程式人生 > >spark SQL (二) 聚合

spark SQL (二) 聚合

      聚合內建功能DataFrames提供共同聚合,例如count()countDistinct()avg()max()min(),等。雖然這些功能是專為DataFrames,spark SQL還擁有型別安全的版本,在其中的一些 scala 和 Java使用強型別資料集的工作。而且,使用者可以預定義的聚合函式,也可以建立自己自定義的聚合函式。

1, 非型別化的使用者定義的聚合函式

      使用者必須擴充套件UserDefinedAggregateFunction 抽象類來實現自定義的非型別集合函式。例如,使用者定義的平均值可能如下所示:

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.types._

object UserDefinedUntypedAggregation {

  object MyAverage extends UserDefinedAggregateFunction {
    // 這集合函式的輸入引數的資料型別
    def inputSchema: StructType = StructType(StructField("inputColumn", LongType) :: Nil)
    // 在聚合緩衝區中的值的資料型別
    def bufferSchema: StructType = {
      StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)
    }
    // 返回值的資料型別
    def dataType: DataType = DoubleType
    // 此函式是否始終在相同的輸入上返回相同的輸出
    def deterministic: Boolean = true
    // 初始化給定的聚合緩衝區。緩衝區本身就是一個“Row”,除了
    // 像標準方法(例如,get(),getBoolean())檢索值之外,還提供
    // 更新其值的機會。請注意,緩衝區內的陣列和對映仍然是
    // 不可變的。
    def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer(0) = 0L
      buffer(1) = 0L
    }
    // 更新給定聚合緩衝區`與來自新的輸入資料buffer``input` 
    def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      if (!input.isNullAt(0)) {
        buffer(0) = buffer.getLong(0) + input.getLong(0)
        buffer(1) = buffer.getLong(1) + 1
      }
    }
    // 合併兩個聚合緩衝劑和儲存更新的緩衝器值回`buffer1` 
    def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
      buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
    }
    // 計算最終結果
    def evaluate(buffer: Row): Double = buffer.getLong(0).toDouble / buffer.getLong(1)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName("Spark SQL user-defined DataFrames aggregation example")
      .getOrCreate()

    // 註冊函式來訪問
    spark.udf.register("myAverage", MyAverage)

    val df = spark.read.json("employees.json")
    df.createOrReplaceTempView("employees")
    df.show()
    // +-------+------+
    // |   name|salary|
    // +-------+------+
    // |Michael|  3000|
    // |   Andy|  4500|
    // | Justin|  3500|
    // |  Berta|  4000|
    // +-------+------+

    val result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees")
    result.show()
    // +--------------+
    // |average_salary|
    // +--------------+
    // |        3750.0|
    // +--------------+
    spark.stop()
  }
}
2,型別安全的使用者定義的聚合函式
       用於強型別資料集的使用者定義聚合圍繞著Aggregator抽象類。例如,型別安全的使用者定義的平均值可能如下所示:
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator

object UserDefinedTypedAggregation {

  case class Employee(name: String, salary: Long)
  case class Average(var sum: Long, var count: Long)

  object MyAverage extends Aggregator[Employee, Average, Double] {
    // 這個聚合的零值。應滿足以下性質:b + zero = b 
    def zero: Average = Average(0L, 0L)
    //合併兩個值產生一個新的值。為了效能,函式可以修改`buffer` 
   //並返回它,而不是構造一個新的物件
    def reduce(buffer: Average, employee: Employee): Average = {
      buffer.sum += employee.salary
      buffer.count += 1
      buffer
    }
    // 合併兩個中間值
    def merge(b1: Average, b2: Average): Average = {
      b1.sum += b2.sum
      b1.count += b2.count
      b1
    }
    // 變換還原的輸出
    def finish(reduction: Average): Double = reduction.sum.toDouble / reduction.count
    // 指定中間值型別的
    def bufferEncoder: Encoder[Average] = Encoders.product
    // 指定最終輸出值型別的
    def outputEncoder: Encoder[Double] = Encoders.scalaDouble
  }
  // $example off:typed_custom_aggregation$

  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName("Spark SQL user-defined Datasets aggregation example")
      .getOrCreate()

    import spark.implicits._

    val ds = spark.read.json("examples/src/main/resources/employees.json").as[Employee]
    ds.show()
    // +-------+------+
    // |   name|salary|
    // +-------+------+
    // |Michael|  3000|
    // |   Andy|  4500|
    // | Justin|  3500|
    // |  Berta|  4000|
    // +-------+------+

    //將函式轉換為“TypedColumn”,並給它一個名稱
    val averageSalary = MyAverage.toColumn.name("average_salary")
    val result = ds.select(averageSalary)
    result.show()
    // +--------------+
    // |average_salary|
    // +--------------+
    // |        3750.0|
    // +--------------+
    spark.stop()
  }
}