1. 程式人生 > 實用技巧 >Flink基礎(三十九):FLINK SQL(十五) 函式(四)自定義函式(二)

Flink基礎(三十九):FLINK SQL(十五) 函式(四)自定義函式(二)

1 標量函式

自定義標量函式可以把 0 到多個標量值對映成 1 個標量值,資料型別裡列出的任何資料型別都可作為求值方法的引數和返回值型別。

想要實現自定義標量函式,你需要擴充套件org.apache.flink.table.functions裡面的ScalarFunction並且實現一個或者多個求值方法。標量函式的行為取決於你寫的求值方法。求值方法必須是public的,而且名字必須是eval

下面的例子展示瞭如何實現一個求雜湊值的函式並在查詢裡呼叫它,詳情可參考開發指南

import org.apache.flink.table.annotation.InputGroup
import org.apache.flink.table.api._
import org.apache.flink.table.functions.ScalarFunction class HashFunction extends ScalarFunction { // 接受任意型別輸入,返回 INT 型輸出 def eval(@DataTypeHint(inputGroup = InputGroup.ANY) o: AnyRef): Int { return o.hashCode(); } } val env = TableEnvironment.create(...) // 在 Table API 裡不經註冊直接“內聯”呼叫函式 env.from("MyTable").select(call(classOf[HashFunction], $"myField"))
// 註冊函式 env.createTemporarySystemFunction("HashFunction", classOf[HashFunction]) // 在 Table API 裡呼叫註冊好的函式 env.from("MyTable").select(call("HashFunction", $"myField")) // 在 SQL 裡呼叫註冊好的函式 env.sqlQuery("SELECT HashFunction(myField) FROM MyTable")

如果你打算使用 Python 實現或呼叫標量函式,詳情可參考Python 標量函式

2 表值函式

跟自定義標量函式一樣,自定義表值函式的輸入引數也可以是 0 到多個標量。但是跟標量函式只能返回一個值不同的是,它可以返回任意多行。返回的每一行可以包含 1 到多列,如果輸出行只包含 1 列,會省略結構化資訊並生成標量值,這個標量值在執行階段會隱式地包裝進行裡。

要定義一個表值函式,你需要擴充套件org.apache.flink.table.functions下的TableFunction,可以通過實現多個名為eval的方法對求值方法進行過載。像其他函式一樣,輸入和輸出型別也可以通過反射自動提取出來。表值函式返回的表的型別取決於TableFunction類的泛型引數T,不同於標量函式,表值函式的求值方法本身不包含返回型別,而是通過collect(T)方法來發送要輸出的行。

在 Table API 中,表值函式是通過.joinLateral(...)或者.leftOuterJoinLateral(...)來使用的。joinLateral運算元會把外表(運算元左側的表)的每一行跟跟表值函式返回的所有行(位於運算元右側)進行 (cross)join。leftOuterJoinLateral運算元也是把外表(運算元左側的表)的每一行跟表值函式返回的所有行(位於運算元右側)進行(cross)join,並且如果表值函式返回 0 行也會保留外表的這一行。

在 SQL 裡面用JOIN或者 以ON TRUE為條件的LEFT JOIN來配合LATERAL TABLE(<TableFunction>)的使用。

下面的例子展示瞭如何實現一個分隔函式並在查詢裡呼叫它,詳情可參考開發指南

import org.apache.flink.table.annotation.DataTypeHint
import org.apache.flink.table.annotation.FunctionHint
import org.apache.flink.table.api._
import org.apache.flink.table.functions.TableFunction
import org.apache.flink.types.Row

@FunctionHint(output = new DataTypeHint("ROW<word STRING, length INT>"))
class SplitFunction extends TableFunction[Row] {

  def eval(str: String): Unit = {
    // use collect(...) to emit a row
    str.split(" ").foreach(s => collect(Row.of(s, Int.box(s.length))))
  }
}

val env = TableEnvironment.create(...)

// 在 Table API 裡不經註冊直接“內聯”呼叫函式
env
  .from("MyTable")
  .joinLateral(call(classOf[SplitFunction], $"myField")
  .select($"myField", $"word", $"length")
env
  .from("MyTable")
  .leftOuterJoinLateral(call(classOf[SplitFunction], $"myField"))
  .select($"myField", $"word", $"length")

// 在 Table API 裡重新命名函式欄位
env
  .from("MyTable")
  .leftOuterJoinLateral(call(classOf[SplitFunction], $"myField").as("newWord", "newLength"))
  .select($"myField", $"newWord", $"newLength")

// 註冊函式
env.createTemporarySystemFunction("SplitFunction", classOf[SplitFunction])

// 在 Table API 裡呼叫註冊好的函式
env
  .from("MyTable")
  .joinLateral(call("SplitFunction", $"myField"))
  .select($"myField", $"word", $"length")
env
  .from("MyTable")
  .leftOuterJoinLateral(call("SplitFunction", $"myField"))
  .select($"myField", $"word", $"length")

// 在 SQL 裡呼叫註冊好的函式
env.sqlQuery(
  "SELECT myField, word, length " +
  "FROM MyTable, LATERAL TABLE(SplitFunction(myField))");
env.sqlQuery(
  "SELECT myField, word, length " +
  "FROM MyTable " +
  "LEFT JOIN LATERAL TABLE(SplitFunction(myField)) ON TRUE")

// 在 SQL 裡重新命名函式欄位
env.sqlQuery(
  "SELECT myField, newWord, newLength " +
  "FROM MyTable " +
  "LEFT JOIN LATERAL TABLE(SplitFunction(myField)) AS T(newWord, newLength) ON TRUE")

如果你打算使用 Scala,不要把表值函式宣告為 Scalaobject,Scalaobject是單例物件,將導致併發問題。

如果你打算使用 Python 實現或呼叫表值函式,詳情可參考Python 表值函式

3 聚合函式

自定義聚合函式(UDAGG)是把一個表(一行或者多行,每行可以有一列或者多列)聚合成一個標量值

上面的圖片展示了一個聚合的例子。假設你有一個關於飲料的表。表裡面有三個欄位,分別是idnameprice,表裡有 5 行資料。假設你需要找到所有飲料裡最貴的飲料的價格,即執行一個max()聚合。你需要遍歷所有 5 行資料,而結果就只有一個數值。

自定義聚合函式是通過擴充套件AggregateFunction來實現的。AggregateFunction的工作過程如下。首先,它需要一個accumulator,它是一個數據結構,儲存了聚合的中間結果。通過呼叫AggregateFunctioncreateAccumulator()方法建立一個空的 accumulator。接下來,對於每一行資料,會呼叫accumulate()方法來更新 accumulator。當所有的資料都處理完了之後,通過呼叫getValue方法來計算和返回最終的結果。

下面幾個方法是每個AggregateFunction必須要實現的:

  • createAccumulator()
  • accumulate()
  • getValue()

Flink 的型別推導在遇到複雜型別的時候可能會推匯出錯誤的結果,比如那些非基本型別和普通的 POJO 型別的複雜型別。所以跟ScalarFunctionTableFunction一樣,AggregateFunction也提供了AggregateFunction#getResultType()AggregateFunction#getAccumulatorType()來分別指定返回值型別和 accumulator 的型別,兩個函式的返回值型別也都是TypeInformation

除了上面的方法,還有幾個方法可以選擇實現。這些方法有些可以讓查詢更加高效,而有些是在某些特定場景下必須要實現的。例如,如果聚合函式用在會話視窗(當兩個會話視窗合併的時候需要 merge 他們的 accumulator)的話,merge()方法就是必須要實現的。

AggregateFunction的以下方法在某些場景下是必須實現的:

  • retract()在 boundedOVER視窗中是必須實現的。
  • merge()在許多批式聚合和會話視窗聚合中是必須實現的。
  • resetAccumulator()在許多批式聚合中是必須實現的。

AggregateFunction的所有方法都必須是public的,不能是static的,而且名字必須跟上面寫的一樣。createAccumulatorgetValuegetResultType以及getAccumulatorType這幾個函式是在抽象類AggregateFunction中定義的,而其他函式都是約定的方法。如果要定義一個聚合函式,你需要擴充套件org.apache.flink.table.functions.AggregateFunction,並且實現一個(或者多個)accumulate方法。accumulate方法可以過載,每個方法的引數型別不同,並且支援變長引數。

AggregateFunction的所有方法的詳細文件如下。

/**
  * Base class for user-defined aggregates and table aggregates.
  *
  * @tparam T   the type of the aggregation result.
  * @tparam ACC the type of the aggregation accumulator. The accumulator is used to keep the
  *             aggregated values which are needed to compute an aggregation result.
  */
abstract class UserDefinedAggregateFunction[T, ACC] extends UserDefinedFunction {

  /**
    * Creates and init the Accumulator for this (table)aggregate function.
    *
    * @return the accumulator with the initial value
    */
  def createAccumulator(): ACC // MANDATORY

  /**
    * Returns the TypeInformation of the (table)aggregate function's result.
    *
    * @return The TypeInformation of the (table)aggregate function's result or null if the result
    *         type should be automatically inferred.
    */
  def getResultType: TypeInformation[T] = null // PRE-DEFINED

  /**
    * Returns the TypeInformation of the (table)aggregate function's accumulator.
    *
    * @return The TypeInformation of the (table)aggregate function's accumulator or null if the
    *         accumulator type should be automatically inferred.
    */
  def getAccumulatorType: TypeInformation[ACC] = null // PRE-DEFINED
}

/**
  * Base class for aggregation functions.
  *
  * @tparam T   the type of the aggregation result
  * @tparam ACC the type of the aggregation accumulator. The accumulator is used to keep the
  *             aggregated values which are needed to compute an aggregation result.
  *             AggregateFunction represents its state using accumulator, thereby the state of the
  *             AggregateFunction must be put into the accumulator.
  */
abstract class AggregateFunction[T, ACC] extends UserDefinedAggregateFunction[T, ACC] {

  /**
    * Processes the input values and update the provided accumulator instance. The method
    * accumulate can be overloaded with different custom types and arguments. An AggregateFunction
    * requires at least one accumulate() method.
    *
    * @param accumulator           the accumulator which contains the current aggregated results
    * @param [user defined inputs] the input value (usually obtained from a new arrived data).
    */
  def accumulate(accumulator: ACC, [user defined inputs]): Unit // MANDATORY

  /**
    * Retracts the input values from the accumulator instance. The current design assumes the
    * inputs are the values that have been previously accumulated. The method retract can be
    * overloaded with different custom types and arguments. This function must be implemented for
    * datastream bounded over aggregate.
    *
    * @param accumulator           the accumulator which contains the current aggregated results
    * @param [user defined inputs] the input value (usually obtained from a new arrived data).
    */
  def retract(accumulator: ACC, [user defined inputs]): Unit // OPTIONAL

  /**
    * Merges a group of accumulator instances into one accumulator instance. This function must be
    * implemented for datastream session window grouping aggregate and dataset grouping aggregate.
    *
    * @param accumulator  the accumulator which will keep the merged aggregate results. It should
    *                     be noted that the accumulator may contain the previous aggregated
    *                     results. Therefore user should not replace or clean this instance in the
    *                     custom merge method.
    * @param its          an [[java.lang.Iterable]] pointed to a group of accumulators that will be
    *                     merged.
    */
  def merge(accumulator: ACC, its: java.lang.Iterable[ACC]): Unit // OPTIONAL

  /**
    * Called every time when an aggregation result should be materialized.
    * The returned value could be either an early and incomplete result
    * (periodically emitted as data arrive) or the final result of the
    * aggregation.
    *
    * @param accumulator the accumulator which contains the current
    *                    aggregated results
    * @return the aggregation result
    */
  def getValue(accumulator: ACC): T // MANDATORY

  /**
    * Resets the accumulator for this [[AggregateFunction]]. This function must be implemented for
    * dataset grouping aggregate.
    *
    * @param accumulator  the accumulator which needs to be reset
    */
  def resetAccumulator(accumulator: ACC): Unit // OPTIONAL

  /**
    * Returns true if this AggregateFunction can only be applied in an OVER window.
    *
    * @return true if the AggregateFunction requires an OVER window, false otherwise.
    */
  def requiresOver: Boolean = false // PRE-DEFINED
}
View Code

下面的例子展示瞭如何:

  • 定義一個聚合函式來計算某一列的加權平均,
  • TableEnvironment中註冊函式,
  • 在查詢中使用函式。

為了計算加權平均值,accumulator 需要儲存加權總和以及資料的條數。在我們的例子裡,我們定義了一個類WeightedAvgAccum來作為 accumulator。Flink 的 checkpoint 機制會自動儲存 accumulator,在失敗時進行恢復,以此來保證精確一次的語義。

我們的WeightedAvg(聚合函式)的accumulate方法有三個輸入引數。第一個是WeightedAvgAccumaccumulator,另外兩個是使用者自定義的輸入:輸入的值ivalue和 輸入的權重iweight。儘管retract()merge()resetAccumulator()這幾個方法在大多數聚合型別中都不是必須實現的,我們也在樣例中提供了他們的實現。請注意我們在 Scala 樣例中也是用的是 Java 的基礎型別,並且定義了getResultType()getAccumulatorType(),因為 Flink 的型別推導對於 Scala 的型別推導做的不是很好。

import java.lang.{Long => JLong, Integer => JInteger}
import org.apache.flink.api.java.tuple.{Tuple1 => JTuple1}
import org.apache.flink.api.java.typeutils.TupleTypeInfo
import org.apache.flink.table.api.Types
import org.apache.flink.table.functions.AggregateFunction

/**
 * Accumulator for WeightedAvg.
 */
class WeightedAvgAccum extends JTuple1[JLong, JInteger] {
  sum = 0L
  count = 0
}

/**
 * Weighted Average user-defined aggregate function.
 */
class WeightedAvg extends AggregateFunction[JLong, CountAccumulator] {

  override def createAccumulator(): WeightedAvgAccum = {
    new WeightedAvgAccum
  }

  override def getValue(acc: WeightedAvgAccum): JLong = {
    if (acc.count == 0) {
        null
    } else {
        acc.sum / acc.count
    }
  }

  def accumulate(acc: WeightedAvgAccum, iValue: JLong, iWeight: JInteger): Unit = {
    acc.sum += iValue * iWeight
    acc.count += iWeight
  }

  def retract(acc: WeightedAvgAccum, iValue: JLong, iWeight: JInteger): Unit = {
    acc.sum -= iValue * iWeight
    acc.count -= iWeight
  }

  def merge(acc: WeightedAvgAccum, it: java.lang.Iterable[WeightedAvgAccum]): Unit = {
    val iter = it.iterator()
    while (iter.hasNext) {
      val a = iter.next()
      acc.count += a.count
      acc.sum += a.sum
    }
  }

  def resetAccumulator(acc: WeightedAvgAccum): Unit = {
    acc.count = 0
    acc.sum = 0L
  }

  override def getAccumulatorType: TypeInformation[WeightedAvgAccum] = {
    new TupleTypeInfo(classOf[WeightedAvgAccum], Types.LONG, Types.INT)
  }

  override def getResultType: TypeInformation[JLong] = Types.LONG
}

// 註冊函式
val tEnv: StreamTableEnvironment = ???
tEnv.registerFunction("wAvg", new WeightedAvg())

// 使用函式
tEnv.sqlQuery("SELECT user, wAvg(points, level) AS avgPoints FROM userScores GROUP BY user")