1. 程式人生 > 實用技巧 >Spark的UDF、UDAF、UDTF函式

Spark的UDF、UDAF、UDTF函式

使用者自定義函式

UDF函式

在操作關係型資料庫時,Spark支援大部分常用SQL函式,而有些函式Spark官方並沒有支援,需要根據業務自行建立。這些函式成為使用者自定義函式(user defined function, UDF)。

接受一個引數,返回一個結果。即一進一出的函式。

例項

實現一個UDF,將name列中的使用者名稱稱全部轉換為大寫字母。

spark.udf.register("toUpperCaseUDF", (column : String) => column.toUpperCase)
spark.sql("SELECT toUpperCaseUDF(name), age FROM t_user").show

UDAF函式

使用者自定義聚合函式(user defined aggregation function, UDAF),該型別函式可以接受並處理多個引數(某一列多個行中的值),之後返回一個值,屬於多進一出的函式。

開發者可以通過繼承UserDefinedAggregateFunction抽象類來實現UDAF。繼承該類需要覆寫8個抽象方法。

object AverageUDAF extends UserDefindAggregationFunction {}

def inputSchema : StructType
def bufferSchema : StructType
def dataType : DataType
def deterministic : Boolean
def initialize(buffer : MutableAggregationBuffer) : Unit
def update(buffer : MutableAggregationBuffer, input : Row) : Unit
def merge(buffer1 : MutableAggregationBuffer, buffer2 : Row) : Unit
def evaluate(buffer : Row) : Any

在聚合過程中,用於存放累加資料的容器是MutableAggregationBuffer型別的例項,該型別繼承自Row型別。整個聚合過程就是將原始表的某一列的多個Row例項取出,將對應列中所有待聚合的值累加到緩衝區的Row例項中。

例項

求每個性別的平均年齡

//inputSchema來指定呼叫avgUDAF函式時傳入的引數型別
override def inputSchema: StructType = {
    StructType(
        List(
        StructField("numInput", DoubleType, nullable = true)
        )
    )
}

//bufferSchema設定UDAF在聚合過程中的緩衝區儲存資料的型別,一個引數是年齡總和,一個引數是累加人數
override def bufferSchema: StructType = {
    StructType(
        List(
        StructField("buffer1", DoubleType, nullable = true)
        StructField("buffer2", LongType, nullable = true)
        )
    )
}

//dataType設定UDAF運算結束時返回的資料型別
override def dataType: DataType = DoubleType

//deterministic判斷UDAF可接收的引數型別與返回的結果型別是否一致
override def deteministic: Boolean = true

//initialize初始化
override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0.0
    buffer(1) = 0L
}

//update用於控制具體的聚合邏輯,通過update方法,將每行參與運算的列累加到聚合緩衝區的Row例項中
//每訪問一行,都會呼叫一次update方法。
override def update(buffer: MutableAggregation, input: Row): Unit = {
    buffer.update(0, buffer.getDouble(0) + input.getDouble(0))
    buffer.update(1, buffer.getLong(1) + 1)
}

//merge用於合併每個分割槽聚合緩衝區的值
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, buffer1.getDouble(0) + buffer2.getDouble(0))
    buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
}

//evaluate方法用於對聚合緩衝區的資料進行最後一次運算
override def evaluate(buffer: Row): Any = {
    buffer.getDouble(0) / buffer.getLong(1)
}

在建立完AverageUDAF類後,要註冊UDAF

spark.udf.register("toDouble", (column: Any) => column.toString.toDouble)
spark.udf.register("avgUDAF", AverageUDAF)
spark.sql("SELECT sex, avgUDAF(toDOUble(age)) as avgAge FROM t_user GROUP BY sex").show

UDTF函式

使用者自定義表生成函式。該型別函式可以將一行中的某一列資料展開,變為基於這一列展開後的多行資料。可以通過DataFrame執行flatMap函式來實現“列轉行”。一進多出。

例項

val tableArray = df1.flatMap(row => {
    val listTuple = new scala.collection.mutable.ListBuffer[(String, String)] ()
    val categoryArray = row.getString(1).split(",")
    for(c <- categoryArray) {
        listTuple.append((row.getString(0), c))
    }
    listTuple
}).collect()
val df2 = spark.createDataFrame(tableArray).toDF("movie", "category")
df.show