1. 程式人生 > >SparkSQL之自定義函式UDF和UDAF

SparkSQL之自定義函式UDF和UDAF

SparkSQL中有兩種自定函式,在我們使用自帶的函式時無法滿足自己的需求時,可以使用自定義函式,SparkSQL中有兩種自定義函式,一種是UDF,另一種是UDAF,和Hive 很類似,但是hive中還有UDTF,一進多出,但是sparkSQL中沒有,這是因為spark中用 flatMap這個函式,可以實現和udtf相同的功能
UDF函式是針對的是一進一出
UDAF針對的是多進一出

udf很簡單,只需要註冊一下,然後寫一個函式,就可以在sql查詢中使用了

    df1.createTempView("user")
    //註冊
    spark.udf.register
("lengthStr",(str:String)=>str.length)//自定義函式 //直接在sql中就可以使用啦 val df2 = spark.sql("select lengthStr(name) from user")

udaf相對來說比較複雜一點,需要繼承一個 UserDefinedAggregateFunction類,在重寫其中的方法,自定義函式求平均值,詳細的步驟在下面的程式碼中

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.expressions.{MutableAggregationBuffer,
UserDefinedAggregateFunction} import org.apache.spark.sql.types._ import org.apache.spark.sql.{DataFrame, Row, SparkSession, types} object UDAFavg { def main(args: Array[String]): Unit = { val spark = SparkSession.builder().appName("avg").master("local").getOrCreate() val sc = spark.sparkContext val sqlContext =
spark.sqlContext val files: RDD[String] = sc.textFile("D:\\read\\teacher.txt") val rowRDD: RDD[Row] = files.map(row => { val split = row.split(" ") Row(split(0), split(1),split(2).toLong) }) /* rowRDD.foreach(row =>{ println(row.getString(0)+" "+row.getString(1)+row.get(2)) })*/ val structType = StructType(List(StructField("subject",StringType,true),StructField("tname",StringType,true), StructField("age",LongType,true))) val df1: DataFrame = spark.createDataFrame(rowRDD,structType) df1.createTempView("teacher") //註冊函式, 自定義一個函式,實現求平均數 spark.udf.register("TeacherAvg",new UDAFavg) //df1.show() spark.sql("select subject,TeacherAvg(age) as avgAGE from teacher group by subject ").show() } } //自定義UDAF函式 class UDAFavg extends UserDefinedAggregateFunction{ //輸入資料型別,求平均值,所以資料型別是LongType(StructType中的型別) override def inputSchema: StructType = { StructType(List(StructField("age",LongType,true)))} //中間結果的型別,這裡定義了兩個中間的型別,因為在求平均值時,首先一個存總的和,一個計算個數,最後的結果是兩者相除 override def bufferSchema: StructType = { StructType(List(StructField("age",LongType),StructField("count",LongType)))} //輸出返回型別 override def dataType: DataType = {LongType} //是否資料同一性,一般都是true override def deterministic: Boolean = true //初始化定義兩個中間值 override def initialize(buffer: MutableAggregationBuffer): Unit = { //型別要和上面定義的位置相對應 buffer(0) = 0L //初始化 總和 buffer(1) = 0L // 個數 } //進行計算, override def update(buffer: MutableAggregationBuffer, input: Row): Unit = { //input是每次的輸入Row型別 buffer(1) = buffer.getAs[Long](1)+ 1 //個數 每次加1 buffer(0) = buffer.getAs[Long](0) + input.getLong(0) // 把每個傳的值進行累加 } //有可能有多個分割槽,多個task ,總後把進行合併 override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { buffer1(1) = buffer1.getAs[Long](1)+ buffer2.getAs[Long](1)//多臺機器中的count的值進行相加 buffer1(0) = buffer1.getAs[Long](0) + buffer2.getLong(0) } //返回的最終結果 override def evaluate(buffer: Row): Any = { buffer.getAs[Long](0) / buffer.getAs[Long](1) } }