Spark UDAF
Published: 2018-11-26
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/**
  * Created by zhen on 2018/11/26.
  */
object AverageUserDefinedAggregateFunction extends UserDefinedAggregateFunction {
  // Input schema of the aggregate function
  override def inputSchema: StructType = StructType(StructField("input", LongType) :: Nil)

  // Schema of the aggregation buffer
  override def bufferSchema: StructType =
    StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)

  // Data type of the result
  override def dataType: DataType = DoubleType

  // Whether the function is deterministic (the same input always produces the same output)
  override def deterministic: Boolean = true

  // Initialize the aggregation buffer
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }

  // Process each input row and update the buffer
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (input.isNullAt(0)) return
    buffer(0) = buffer.getLong(0) + input.getLong(0) // running sum
    buffer(1) = buffer.getLong(1) + 1                // running count
  }

  // Merge two buffers computed on different partitions
  override def merge(bufferLeft: MutableAggregationBuffer, bufferRight: Row): Unit = {
    bufferLeft(0) = bufferLeft.getLong(0) + bufferRight.getLong(0)
    bufferLeft(1) = bufferLeft.getLong(1) + bufferRight.getLong(1)
  }

  // Compute the final result from the buffer
  override def evaluate(buffer: Row): Any = buffer.getLong(0).toDouble / buffer.getLong(1)
}

object SparkUdaf {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName("udaf")
      .master("local[2]")
      .getOrCreate()

    spark.read.json("E:/BDS/newsparkml/src/udaf.json").createOrReplaceTempView("user")
    spark.udf.register("average", AverageUserDefinedAggregateFunction)
    spark.sql("select count(*) count, average(age) avg_age from user").show()
  }
}
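If you don't have the E:/BDS/newsparkml/src/udaf.json file handy, the sketch below shows the same registration and query against an in-memory DataFrame instead. The object name SparkUdafInlineExample and the sample names and ages are hypothetical, added only for illustration; it assumes AverageUserDefinedAggregateFunction from the listing above is in scope.

import org.apache.spark.sql.SparkSession

object SparkUdafInlineExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("udaf-inline")
      .master("local[2]")
      .getOrCreate()
    import spark.implicits._

    // Hypothetical sample data standing in for udaf.json
    Seq(("zhangsan", 20L), ("lisi", 24L), ("wangwu", 31L))
      .toDF("name", "age")
      .createOrReplaceTempView("user")

    // Register the UDAF defined above and use it in SQL
    spark.udf.register("average", AverageUserDefinedAggregateFunction)
    spark.sql("select count(*) count, average(age) avg_age from user").show()

    spark.stop()
  }
}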
Result: