Spark SQL -- UDAF Functions
Published: 2019-01-05
Requirement: implement a custom aggregate function by extending UserDefinedAggregateFunction. Example: compute the average salary of the employees.
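Both examples below read an employees.json file. As a point of reference, the file is assumed to be line-delimited JSON in the same shape as the sample file shipped with the Spark distribution (the actual local file may differ):

{"name":"Michael", "salary":3000}
{"name":"Andy", "salary":4500}
{"name":"Justin", "salary":3500}
{"name":"Berta", "salary":4000}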
Weakly typed (untyped) aggregate function:
package com.jiangnan.spark

import org.apache.spark.SparkConf
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/**
 * Weakly typed (untyped) UDAF:
 * computes the average salary of the employees.
 */
class AverageSalaryRuo extends UserDefinedAggregateFunction {
  // Schema of the input data
  override def inputSchema: StructType = StructType(StructField("salary", IntegerType) :: Nil)

  // Schema of the aggregation buffer shared within each partition
  override def bufferSchema: StructType =
    StructType(StructField("sum", LongType) :: StructField("count", IntegerType) :: Nil)

  // Type of the output value
  override def dataType: DataType = DoubleType

  // Whether the same input always produces the same output: true here
  override def deterministic: Boolean = true

  // Initialize the aggregation buffer for each partition
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0
  }

  // Called for every input row in a partition to update the buffer
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // Add the incoming salary to the running total; the buffer is accessed much like a ResultSet
    buffer(0) = buffer.getLong(0) + input.getInt(0)
    // Increment the employee count for every salary that arrives
    buffer(1) = buffer.getInt(1) + 1
  }

  // Merge the buffers produced by different partitions
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getInt(1) + buffer2.getInt(1)
  }

  // Compute the final result
  override def evaluate(buffer: Row): Any = {
    // Average salary = total salary / number of employees
    buffer.getLong(0).toDouble / buffer.getInt(1)
  }
}

object AverageSalaryRuo extends App {
  val conf = new SparkConf().setAppName("udaf").setMaster("local[3]")
  val spark = SparkSession.builder().config(conf).getOrCreate()

  val data = spark.read.json("C:\\Users\\zhang\\Desktop\\employees.json")
  data.createOrReplaceTempView("employee")

  // Register the custom aggregate function
  spark.udf.register("avgSalary", new AverageSalaryRuo)
  spark.sql("select avgSalary(salary) from employee").show()

  spark.stop()
}
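As a usage note, the registered UDAF is not limited to SQL text; a minimal sketch (assuming the same SparkSession and the "avgSalary" registration above) of calling it through the DataFrame API:

import org.apache.spark.sql.functions.callUDF

// Equivalent to the SQL query above, expressed through the DataFrame API
data.agg(callUDF("avgSalary", data("salary"))).show()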
Strongly typed aggregate function:
package com.jiangnan.spark

import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}

/**
 * Strongly typed UDAF:
 * computes the average salary of the employees.
 */
// The strongly typed version is built around case classes for the input and the buffer
case class Employee(name: String, salary: Long)
case class Average(var sum: Long, var count: Int)

class AverageSalaryQiang extends Aggregator[Employee, Average, Double] {
  // Initial (zero) value of the buffer
  override def zero: Average = Average(0L, 0)

  // Aggregation within a partition, analogous to update() in the untyped version
  override def reduce(b: Average, a: Employee): Average = {
    b.sum = b.sum + a.salary
    b.count = b.count + 1
    b
  }

  // Merge buffers coming from different partitions
  override def merge(b1: Average, b2: Average): Average = {
    b1.sum = b1.sum + b2.sum
    b1.count = b1.count + b2.count
    b1
  }

  // Compute the final result
  override def finish(reduction: Average): Double = {
    reduction.sum.toDouble / reduction.count
  }

  // Encoder for the buffer type
  override def bufferEncoder: Encoder[Average] = Encoders.product

  // Encoder for the output type
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

object AverageSalaryQiang extends App {
  val conf = new SparkConf().setAppName("udaf").setMaster("local[3]")
  val spark = SparkSession.builder().config(conf).getOrCreate()
  import spark.implicits._

  val employee = spark.read.json("C:\\Users\\zhang\\Desktop\\employees.json").as[Employee]
  employee.show()
  employee.createOrReplaceTempView("employee")

  // Turn the aggregator into a typed column
  val aaa = new AverageSalaryQiang().toColumn.name("aaaa")
  spark.sql("select * from employee").show()
  //spark.sql("select aaaa(salary) from employee").show()
  employee.select(aaa).show()

  spark.stop()
}
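The commented-out SQL line above does not work as written: toColumn.name("aaaa") only names a typed column, it does not register a SQL function. On Spark 3.0 and later, a typed Aggregator can be exposed to SQL through org.apache.spark.sql.functions.udaf. The sketch below is a hypothetical adaptation (the aggregator is rewritten to take the salary value directly instead of an Employee row, and the function name "avgSalaryTyped" is made up for illustration):

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.functions

case class AvgBuffer(var sum: Long, var count: Int)

// Aggregator over the raw salary column rather than a whole Employee row
object SalaryAvg extends Aggregator[Long, AvgBuffer, Double] {
  def zero: AvgBuffer = AvgBuffer(0L, 0)
  def reduce(b: AvgBuffer, salary: Long): AvgBuffer = { b.sum += salary; b.count += 1; b }
  def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer = { b1.sum += b2.sum; b1.count += b2.count; b1 }
  def finish(b: AvgBuffer): Double = b.sum.toDouble / b.count
  def bufferEncoder: Encoder[AvgBuffer] = Encoders.product
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

// Register the typed aggregator as an untyped SQL function (Spark 3.0+ only)
// spark.udf.register("avgSalaryTyped", functions.udaf(SalaryAvg))
// spark.sql("select avgSalaryTyped(salary) from employee").show()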