大資料學習之路90-sparkSQL自定義聚合函式UDAF
阿新 • • 發佈:2018-11-09
什麼是UDAF?UDAF(User-Defined Aggregate Function,使用者自定義聚合函式)接收 N 行輸入、輸出一個結果,屬於聚合類函式。
接下來我們寫一個計算幾何平均數的自定義聚合函式作為例子。n 個正數的幾何平均數是它們乘積的 n 次方根:$\sqrt[n]{x_1 x_2 \cdots x_n}$。
我們從頭寫起,先來看看參與計算的資料如何產生(`spark.range(1, 10)` 生成 1 到 9 的數字,不包含 10):
package com.test.SparkSQL

import java.lang

import org.apache.spark.sql.{Dataset, SparkSession}

/** Demo entry point: starts a local SparkSession and prints the input numbers for the UDAF example. */
object UDAFDemo {
  def main(args: Array[String]): Unit = {
    val session = SparkSession
      .builder()
      .master("local[*]")
      .appName("UDAFDemo")
      .getOrCreate()

    // range(start, end) is end-exclusive: a single LongType column "id" holding 1 through 9.
    val numbers: Dataset[lang.Long] = session.range(1, 10)
    numbers.show()
  }
}
生成結果:
接下來我們使用自定義聚合函式計算幾何平均數:
package com.test.SparkSQL

import java.lang

import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Dataset, Row, SparkSession, types}

/** Registers [[GeometriMean]] as the SQL function `gm` and applies it to the numbers 1..9. */
object UDAFDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("UDAFDemo")
      .master("local[*]")
      .getOrCreate()
    try {
      val ds: Dataset[lang.Long] = spark.range(1, 10)
      // Unlike createTempView, this does not throw if "v_num" is already registered in the session.
      ds.createOrReplaceTempView("v_num")
      val gm = new GeometriMean
      spark.udf.register("gm", gm)
      spark.sql("select gm(id) as gm from v_num").show()
    } finally {
      // Release the session's resources even if the query fails.
      spark.stop()
    }
  }
}

/**
 * Geometric-mean aggregate: (x1 * x2 * ... * xn)^(1/n).
 *
 * The product is accumulated in log space (sum of ln(x)) and exponentiated once in
 * [[evaluate]]. A plain running product overflows Double after a few hundred values
 * (171! already exceeds Double.MaxValue), whereas the log sum stays finite.
 *
 * Semantics:
 *  - NULL inputs are ignored (not counted, not multiplied in).
 *  - An aggregate over no non-null values returns NULL.
 *  - Inputs are expected to be non-negative: any zero yields 0.0, any negative yields NaN.
 *
 * NOTE(review): UserDefinedAggregateFunction is deprecated since Spark 3.0 in favour of
 * `Aggregator` + `functions.udaf`; kept here to match the Spark 2.x API the article uses.
 */
class GeometriMean extends UserDefinedAggregateFunction {
  // One input column of type Double; Spark casts compatible numeric arguments (e.g. LongType `id`) to it.
  override def inputSchema: StructType = StructType(List(StructField("value", DoubleType)))

  // Intermediate state: count of non-null values seen, and the sum of their natural logarithms.
  override def bufferSchema: StructType = StructType(
    List(
      StructField("count", LongType),
      StructField("logSum", DoubleType)
    )
  )

  // The aggregate's result type.
  override def dataType: DataType = DoubleType

  // The same input always produces the same result.
  override def deterministic: Boolean = true

  // Empty state: no values seen yet, so the log sum is 0 (the log of the empty product, 1).
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0.0
  }

  // Folds one input row into a partition-local buffer.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // getAs[Double] on a NULL cell unboxes to 0.0, which would silently zero the result; skip NULLs.
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + 1L
      buffer(1) = buffer.getDouble(1) + math.log(input.getDouble(0))
    }
  }

  // Combines two partial buffers: counts add, and log sums add (equivalent to multiplying the products).
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getDouble(1) + buffer2.getDouble(1)
  }

  /** Returns exp(logSum / count), or null when no non-null value was aggregated (avoids 1/0). */
  override def evaluate(buffer: Row): Any = {
    val count = buffer.getLong(0)
    if (count == 0L) null
    else math.exp(buffer.getDouble(1) / count)
  }
}
執行結果: