spark SQL (二) 聚合
阿新 • • 發佈:2019-01-05
聚合內建功能DataFrames提供共同聚合,例如count()
,countDistinct()
,avg()
,max()
,min()
,等。雖然這些功能是專為DataFrames,spark
SQL還擁有型別安全的版本,在其中的一些 scala 和 Java使用強型別資料集的工作。而且,使用者可以預定義的聚合函式,也可以建立自己自定義的聚合函式。
1, 非型別化的使用者定義的聚合函式
使用者必須擴充套件UserDefinedAggregateFunction 抽象類來實現自定義的非型別集合函式。例如,使用者定義的平均值可能如下所示:
2,型別安全的使用者定義的聚合函式import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.expressions.MutableAggregationBuffer import org.apache.spark.sql.expressions.UserDefinedAggregateFunction import org.apache.spark.sql.types._ object UserDefinedUntypedAggregation { object MyAverage extends UserDefinedAggregateFunction { // 這集合函式的輸入引數的資料型別 def inputSchema: StructType = StructType(StructField("inputColumn", LongType) :: Nil) // 在聚合緩衝區中的值的資料型別 def bufferSchema: StructType = { StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil) } // 返回值的資料型別 def dataType: DataType = DoubleType // 此函式是否始終在相同的輸入上返回相同的輸出 def deterministic: Boolean = true // 初始化給定的聚合緩衝區。緩衝區本身就是一個“Row”,除了 // 像標準方法(例如,get(),getBoolean())檢索值之外,還提供 // 更新其值的機會。請注意,緩衝區內的陣列和對映仍然是 // 不可變的。 def initialize(buffer: MutableAggregationBuffer): Unit = { buffer(0) = 0L buffer(1) = 0L } // 更新給定聚合緩衝區`與來自新的輸入資料buffer``input` def update(buffer: MutableAggregationBuffer, input: Row): Unit = { if (!input.isNullAt(0)) { buffer(0) = buffer.getLong(0) + input.getLong(0) buffer(1) = buffer.getLong(1) + 1 } } // 合併兩個聚合緩衝劑和儲存更新的緩衝器值回`buffer1` def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0) buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1) } // 計算最終結果 def evaluate(buffer: Row): Double = buffer.getLong(0).toDouble / buffer.getLong(1) } def main(args: Array[String]): Unit = { val spark = SparkSession .builder() .appName("Spark SQL user-defined DataFrames aggregation example") .getOrCreate() // 註冊函式來訪問 spark.udf.register("myAverage", MyAverage) val df = spark.read.json("employees.json") df.createOrReplaceTempView("employees") df.show() // +-------+------+ // | name|salary| // +-------+------+ // |Michael| 3000| // | Andy| 4500| // | Justin| 3500| // | Berta| 4000| // +-------+------+ val result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees") result.show() // +--------------+ // |average_salary| // +--------------+ // | 3750.0| // +--------------+ spark.stop() } }
用於強型別資料集的使用者定義聚合圍繞著Aggregator抽象類。例如,型別安全的使用者定義的平均值可能如下所示:
import org.apache.spark.sql.{Encoder, Encoders, SparkSession} import org.apache.spark.sql.expressions.Aggregator object UserDefinedTypedAggregation { case class Employee(name: String, salary: Long) case class Average(var sum: Long, var count: Long) object MyAverage extends Aggregator[Employee, Average, Double] { // 這個聚合的零值。應滿足以下性質:b + zero = b def zero: Average = Average(0L, 0L) //合併兩個值產生一個新的值。為了效能,函式可以修改`buffer` //並返回它,而不是構造一個新的物件 def reduce(buffer: Average, employee: Employee): Average = { buffer.sum += employee.salary buffer.count += 1 buffer } // 合併兩個中間值 def merge(b1: Average, b2: Average): Average = { b1.sum += b2.sum b1.count += b2.count b1 } // 變換還原的輸出 def finish(reduction: Average): Double = reduction.sum.toDouble / reduction.count // 指定中間值型別的 def bufferEncoder: Encoder[Average] = Encoders.product // 指定最終輸出值型別的 def outputEncoder: Encoder[Double] = Encoders.scalaDouble } // $example off:typed_custom_aggregation$ def main(args: Array[String]): Unit = { val spark = SparkSession .builder() .appName("Spark SQL user-defined Datasets aggregation example") .getOrCreate() import spark.implicits._ val ds = spark.read.json("examples/src/main/resources/employees.json").as[Employee] ds.show() // +-------+------+ // | name|salary| // +-------+------+ // |Michael| 3000| // | Andy| 4500| // | Justin| 3500| // | Berta| 4000| // +-------+------+ //將函式轉換為“TypedColumn”,並給它一個名稱 val averageSalary = MyAverage.toColumn.name("average_salary") val result = ds.select(averageSalary) result.show() // +--------------+ // |average_salary| // +--------------+ // | 3750.0| // +--------------+ spark.stop() } }