程式人生 > Spark SQL--UDAF函式

Spark SQL--UDAF函式

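需求：實現自定義聚合函式（UDAF），案例為計算員工的平均工資。以下給出兩種寫法：弱型別版本繼承 UserDefinedAggregateFunction；強型別版本繼承 Aggregator。（注意：自 Spark 3.0 起 UserDefinedAggregateFunction 已被標記為 deprecated，官方建議改用 Aggregator，並透過 functions.udaf 註冊給 SQL 使用。）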

弱型別聚合函式（繼承 UserDefinedAggregateFunction）：

package com.jiangnan.spark
import org.apache.spark.SparkConf
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
/**
  * Weakly-typed (untyped) UDAF that computes the average salary of employees.
  *
  * Input: a single `salary` column (LongType; narrower integral columns are widened
  * by Spark's implicit input casting). Buffer: running `sum` and `count`, both LongType.
  * Output: the average as DoubleType.
  *
  * Like Spark's built-in `avg`, null salaries are ignored. The result is null (rather
  * than NaN from 0.0 / 0) when no non-null salary was aggregated.
  */
class AverageSalaryRuo extends UserDefinedAggregateFunction {
  // Input argument schema. LongType matches what spark.read.json infers for integral
  // numbers, so salaries are not silently narrowed to Int.
  override def inputSchema: StructType = StructType(StructField("salary", LongType) :: Nil)

  // Per-partition aggregation buffer. The count is a Long so very large groups
  // (more than Int.MaxValue rows) cannot overflow it.
  override def bufferSchema: StructType =
    StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)

  // Type of the final result.
  override def dataType: DataType = DoubleType

  // The same input always produces the same output.
  override def deterministic: Boolean = true

  // Empty buffer: zero total, zero rows.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }

  // Folds one input row into the partition-local buffer.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // Row.getLong throws on a null value, and SQL aggregates ignore nulls, so skip them.
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + input.getLong(0)
      buffer(1) = buffer.getLong(1) + 1L
    }
  }

  // Combines two partial buffers coming from different partitions.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // Produces the average, or null when nothing was aggregated.
  override def evaluate(buffer: Row): Any = {
    val count = buffer.getLong(1)
    if (count == 0L) null else buffer.getLong(0).toDouble / count
  }
}
/**
  * Demo entry point: registers [[AverageSalaryRuo]] as the SQL function `avgSalary`
  * and applies it to the `salary` column of an employees JSON file.
  *
  * Usage: `AverageSalaryRuo [path-to-employees.json]`. Without an argument the
  * original sample path is used.
  */
object AverageSalaryRuo {
  private val DefaultInputPath = "C:\\Users\\zhang\\Desktop\\employees.json"

  def main(args: Array[String]): Unit = {
    val inputPath = args.headOption.getOrElse(DefaultInputPath)
    val conf = new SparkConf().setAppName("udaf").setMaster("local[3]")
    val spark = SparkSession.builder().config(conf).getOrCreate()
    try {
      val data = spark.read.json(inputPath)
      data.createOrReplaceTempView("employee")
      // Register the UDAF under a SQL-visible name so spark.sql can call it.
      spark.udf.register("avgSalary", new AverageSalaryRuo)
      spark.sql("select avgSalary(salary) from employee").show()
    } finally {
      // Stop the session even if reading or querying fails.
      spark.stop()
    }
  }
}

強型別聚合函式（繼承 Aggregator，透過 Dataset API 使用）：

package com.jiangnan.spark
import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
// With the typed API, input, buffer and output types are expressed as case classes.

/** One input row: an employee's name and salary. Field names must match the input columns. */
case class Employee(name: String, salary: Long)

/**
  * Aggregation buffer: running salary total and number of employees seen.
  * The fields are `var`s so [[AverageSalaryQiang]] can update the buffer in place
  * instead of allocating a new object per row.
  */
case class Average(var sum: Long, var count: Int)

/**
  * Strongly-typed aggregator that computes the average salary of employees.
  *
  * Input type is [[Employee]], the buffer is [[Average]], and the result is a Double.
  * With no input rows the result is NaN (0.0 / 0).
  */
class AverageSalaryQiang extends Aggregator[Employee, Average, Double] {
  // Empty buffer. A fresh instance is created on every call, which matters because
  // reduce/merge below mutate their buffer argument.
  override def zero: Average = Average(0L, 0)

  // Folds one input row into the partition-local buffer (the typed analogue of UDAF `update`).
  override def reduce(b: Average, a: Employee): Average = {
    b.sum += a.salary
    b.count += 1
    b
  }

  // Combines two partial buffers from different partitions into b1.
  override def merge(b1: Average, b2: Average): Average = {
    b1.sum += b2.sum
    b1.count += b2.count
    b1
  }

  // Final result: total salary divided by the number of employees.
  override def finish(reduction: Average): Double =
    reduction.sum.toDouble / reduction.count

  // Encoder for the intermediate buffer (a case class, hence a Product encoder).
  override def bufferEncoder: Encoder[Average] = Encoders.product

  // Encoder for the output value.
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
/**
  * Demo entry point: reads employees as a typed `Dataset[Employee]` and applies
  * [[AverageSalaryQiang]] through the Dataset API.
  *
  * Usage: `AverageSalaryQiang [path-to-employees.json]`. Without an argument the
  * original sample path is used.
  */
object AverageSalaryQiang {
  private val DefaultInputPath = "C:\\Users\\zhang\\Desktop\\employees.json"

  def main(args: Array[String]): Unit = {
    val inputPath = args.headOption.getOrElse(DefaultInputPath)
    val conf = new SparkConf().setAppName("udaf").setMaster("local[3]")
    val spark = SparkSession.builder().config(conf).getOrCreate()
    import spark.implicits._
    try {
      val employee = spark.read.json(inputPath).as[Employee]
      employee.show()
      employee.createOrReplaceTempView("employee")
      // toColumn does NOT register a SQL function: it turns the Aggregator into a
      // TypedColumn that is used with the Dataset API (select below), not by name in spark.sql.
      val averageSalary = new AverageSalaryQiang().toColumn.name("aaaa")
      spark.sql("select * from employee").show()
      employee.select(averageSalary).show()
    } finally {
      // Stop the session even if reading or querying fails.
      spark.stop()
    }
  }
}