1. 程式人生 > 實用技巧 >Spark(十三)【SparkSQL自定義UDF/UDAF函式】

Spark(十三)【SparkSQL自定義UDF/UDAF函式】

目錄

一.UDF(一進一出)

步驟

① 註冊UDF函式,可以使用匿名函式。

② 在sql查詢的時候使用自定義的UDF。

示例

import org.apache.spark.sql.{DataFrame, SparkSession}

/**
 * @description: UDF一進一出
 * @author: HaoWu
 * @create: 2020年08月09日
 */
object UDF_Test {
  def main(args: Array[String]): Unit = {
    //建立SparkSession
    val session: SparkSession = SparkSession.builder
      .master("local[*]")
      .appName("MyApp")
      .getOrCreate()
    //註冊UDF
    session.udf.register("addHello",(name:String) => "hello:"+name)
    //讀取json格式檔案{"name":"zhangsan","age":20},建立DataFrame
    val df: DataFrame = session.read.json("input/1.txt")
    //建立臨時檢視:person
    df.createOrReplaceTempView("person")
    //查詢的時候使用UDF
    session.sql(
      """select
        |addHello(name),
        |age
        |from person
        |""".stripMargin).show
  }
}

結果

|addHello(name)|age|
+--------------+---+
|hello:zhangsan| 20|
|    hello:lisi| 30|
+--------------+---+

二.UDAF(多近一出)

spark2.X 實現方式

2.X版本:UserDefinedAggregateFunction 無型別或弱型別

步驟

①繼承UserDefinedAggregateFunction,實現其中的方法

②建立函式物件,註冊函式,在sql中使用

    //建立UDFA物件
    val avgDemo1: Avg_UDAF_Demo1 = new Avg_UDAF_Demo1
    //在spark中註冊聚合函式
    spark.udf.register("ageDemo1", avgDemo1)
案例

需求:實現avg()聚合函式的功能,要求結果是Double型別

程式碼實現

①繼承UserDefinedAggregateFunction,實現其中的方法
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StructField, StructType}

/**
 * @description: UDAF(多近一出):求age的平均值
 *              2.X 版本繼承UserDefinedAggregateFunction類,弱型別
 *               非常類似累加器,aggregateByKey運算元的操作,有個ZeroValue,不斷將輸入的值做歸約操作,然後再賦值給ZeroValue
 * @author: HaoWu
 * @create: 2020年08月08日
 */
class Avg_UDAF_Demo1 extends UserDefinedAggregateFunction {
  //聚合函式輸入引數的資料型別,
  override def inputSchema = StructType(StructField("age", LongType) :: Nil)

  //聚合函式緩衝區中值的資料型別(sum,count)
  override def bufferSchema = StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)

  //函式返回值的資料型別
  override def dataType = DoubleType

  //穩定性:對於相同的輸入是否一直返回相同的輸出,一般都是true
  override def deterministic = true

  //函式緩衝區初始化,就是ZeroValue清空
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    //快取區看做一個數組,將每個元素置空
    //sum
    buffer(0) = 0L
    //count
    buffer(1) = 0L

  }
  //更新緩衝區中的資料->將輸入的值和快取區資料合併
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    //input是Row型別,通過getXXX(索引值)取資料
    if (!input.isNullAt(0)) {
      val age = input.getLong(0)
      buffer(0) = buffer.getLong(0) + age
      buffer(1) = buffer.getLong(1) + 1
    }
  }
  //合併緩衝區 (sum1,count1) + (sum2,count2) 合併
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }
  //計算最終結果
  override def evaluate(buffer: Row) = buffer.getLong(0).toDouble/buffer.getLong(1)
}
②建立函式物件,註冊函式,在sql中使用
/**
 * @description: 實現集合函式avg的功能
 * @author: HaoWu
 * @create: 2020年08月13日
 */
object UDAF_Test {
  def main(args: Array[String]): Unit = {
    
    //建立SparkSession
    val spark: SparkSession = SparkSession.builder
      .master("local[*]")
      .appName("MyApp")
      .getOrCreate()
    //讀取json格式檔案{"name":"zhangsan","age":20}
    val df: DataFrame = spark.read.json("input/1.txt")
    //建立臨時檢視:person
    df.createOrReplaceTempView("person")
    //建立UDFA物件
    val avgDemo1: Avg_UDAF_Demo1 = new Avg_UDAF_Demo1
    //在spark中註冊聚合函式
    spark.udf.register("ageDemo1", avgDemo1)
    //查詢的時候使用UDF
    spark.sql(
      """select
        |ageDemo1(age)
        |from person
        |""".stripMargin).show
  }
}

spark3.X實現方式

3.x版本: 認為2.X繼承UserDefinedAggregateFunction的方式過時,推薦繼承Aggregator ,是強型別

步驟

①繼承Aggregator [-IN, BUF, OUT],宣告泛型,實現其中的方法

    abstract class Aggregator[-IN, BUF, OUT]  
        IN: 輸入的型別      
        BUF:  緩衝區型別     
        OUT: 輸出的型別      

②建立函式物件,註冊函式,在sql中使用

    //建立UDFA物件
    val avgDemo2: Avg_UDAF_Demo2 = new Avg_UDAF_Demo2
    //在spark中註冊聚合函式
    spark.udf.register("myAvg",functions.udaf(avgDemo2))

注意:2.X和3.X的註冊方式不同

案例

需求:實現avg()聚合函式的功能,要求結果是Double型別

程式碼實現

①繼承Aggregator [-IN, BUF, OUT],宣告泛型,實現其中的方法

其中緩衝區資料用樣例類進行封裝。

MyBuffer類

/**
 * 定義MyBuffer樣例類
 * @param sum  組資料sum和
 * @param count  組的資料個數
 */
case class MyBuffer(var sum: Long, var count: Long)

自定義UDAF函式

import org.apache.spark.sql.Encoders
import org.apache.spark.sql.expressions.Aggregator

/**
 * @description: UDAF(多近一出):求age的平均值
 *              3.X Aggregator,強型別
 *               非常類似累加器,aggregateByKey運算元的操作,有個ZeroValue,不斷將輸入的值做歸約操作,然後再賦值給ZeroValue
 * @author: HaoWu
 * @create: 2020年08月08日
 */
class Avg_UDAF_Demo2 extends Aggregator[Long, MyBuffer, Double] {
  //函式緩衝區初始化,就是ZeroValue清空
  override def zero = MyBuffer(0L, 0L)

  //將輸入的值和快取區資料合併
  override def reduce(b: MyBuffer, a: Long) = {
    b.sum = b.sum + a
    b.count = b.count + 1
    b
  }

  //合併緩衝區
  override def merge(b1: MyBuffer, b2: MyBuffer) = {
    b1.sum = b1.sum + b2.sum
    b1.count = b1.count + b2.count
    b1
  }

  //計算最終結果
  override def finish(reduction: MyBuffer) = reduction.sum.toDouble / reduction.count

  /* scala中
     常見的資料型別: Encoders.scalaXXX
     自定義的型別:ExpressionEncoder[T]() 返回 Encoder[T]
     樣例類(都是Product型別): Encoders.product[T],返回Produce型別的Encoder!
                                            */
  //快取區的Encoder型別
  override def bufferEncoder = Encoders.product[MyBuffer]

  //輸出結果的Encoder型別
  override def outputEncoder = Encoders.scalaDouble
}
②建立函式物件,註冊函式,在sql中使用
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.{DataFrame, Row, SparkSession, functions}

/**
 * @description: 實現集合函式avg的功能
 * @author: HaoWu
 * @create: 2020年08月13日
 */
object UDAF_Test {
  def main(args: Array[String]): Unit = {

    //建立SparkSession
    val spark: SparkSession = SparkSession.builder
      .master("local[*]")
      .appName("MyApp")
      .getOrCreate()
    //讀取json格式檔案{"name":"zhangsan","age":20}
    val df: DataFrame = spark.read.json("input/1.txt")
    //建立臨時檢視:person
    df.createOrReplaceTempView("person")
    //建立UDFA物件
    val avgDemo2: Avg_UDAF_Demo2 = new Avg_UDAF_Demo2
    //在spark中註冊聚合函式
    spark.udf.register("myAvg",functions.udaf(avgDemo2))
    //查詢的時候使用UDF
    spark.sql(
      """select
        |myAvg(age)
        |from person
        |""".stripMargin).show
  }
}