Spark UDF Functions
package cn.com.systex
import scala.reflect.runtime.universe
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.functions.callUDF
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.ArrayType
import org.apache.spark.sql.types.StringType
import java.sql.Timestamp
import java.sql.Date
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.DateType
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.types.StructField
/**
* DateTime: 2015-12-25 10:41:42 AM
*
*/
// Define a class that represents a date range
case class DateRange(startDate: Timestamp, endDate: Timestamp) {
def in(targetDate: Date): Boolean = {
targetDate.before(endDate) && targetDate.after(startDate)
}
override def toString(): String = {
startDate.toLocaleString() + " " + endDate.toLocaleString();
}
}
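// Example: DateRange(Timestamp.valueOf("2015-01-01 00:00:00"), Timestamp.valueOf("2015-12-31 00:00:00"))
//   .in(Date.valueOf("2015-06-15")) returns true; note that both endpoints are exclusive, since before/after are strict comparisons.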
// Define a UDAF that aggregates by year and then compares the two years; it must implement the methods declared in UserDefinedAggregateFunction
class YearOnYearCompare(current: DateRange) extends UserDefinedAggregateFunction {
val previous: DateRange = DateRange(subtractOneYear(current.startDate), subtractOneYear(current.endDate))
println(current)
println(previous)
// inputSchema describes the input columns the UDAF consumes. The StructField names carry no special meaning; treat them as placeholder column names for this internal structure.
// Which DataFrame columns the UDAF actually operates on is decided by the caller, but the data types must match what is declared here, i.e. DoubleType and DateType.
def inputSchema: StructType = {
StructType(StructField("metric", DoubleType) :: StructField("timeCategory", DateType) :: Nil)
}
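// For example, the SQL statement later in this article, select yearOnYear(sales, saleDate) from sales,
// binds the sales column to the "metric" field and the saleDate column to the "timeCategory" field.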
// Schema of the intermediate buffer that holds partial results during aggregation
def bufferSchema: StructType = {
StructType(StructField("sumOfCurrent", DoubleType) :: StructField("sumOfPrevious", DoubleType) :: Nil)
}
// Return type of the UDAF
def dataType: org.apache.spark.sql.types.DataType = DoubleType
// Indicates whether the UDAF always produces the same result for a given set of inputs
def deterministic: Boolean = true
// Initialize the intermediate aggregation buffer
def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer.update(0, 0.0)
buffer.update(1, 0.0)
}
// The second argument, input: Row, is not a raw DataFrame row but a row projected through inputSchema. In this example each input therefore carries exactly two field values.
def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
if (current.in(input.getAs[Date](1))) {
buffer(0) = buffer.getAs[Double](0) + input.getAs[Double](0)
}
if (previous.in(input.getAs[Date](1))) {
buffer(1) = buffer.getAs[Double](1) + input.getAs[Double](0)
}
}
// Merge two aggregation buffers and store the combined result back into the MutableAggregationBuffer
def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(0) = buffer1.getAs[Double](0) + buffer2.getAs[Double](0)
buffer1(1) = buffer1.getAs[Double](1) + buffer2.getAs[Double](1)
}
// Compute the final result from the values in the aggregation buffer
def evaluate(buffer: Row): Any = {
if (buffer.getDouble(1) == 0.0) {
0.0
} else {
(buffer.getDouble(0) - buffer.getDouble(1)) / buffer.getDouble(1) * 100
}
}
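// Worked example (hypothetical figures): with sumOfCurrent = 1500.0 and sumOfPrevious = 1000.0,
// the result is (1500 - 1000) / 1000 * 100 = 50.0, i.e. a 50% year-on-year increase.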
private def subtractOneYear(date: Timestamp): Timestamp = {
val prev = new Timestamp(date.getTime)
prev.setYear(prev.getYear - 1)
prev
}
}
/**
* Spark 1.5 DataFrame API Highlights: Date/Time/String Handling, Time Intervals, and UDAFs
* https://databricks.com/blog/2015/09/16/spark-1-5-dataframe-api-highlights-datetimestring-handling-time-intervals-and-udafs.html
*/
object SimpleDemo {
def main(args: Array[String]): Unit = {
val dir = "D:/Program/spark/examples/src/main/resources/";
val sc = new SparkContext(new SparkConf().setMaster("local[4]").setAppName("sqltest"))
val sqlContext = new org.apache.spark.sql.SQLContext(sc)
import sqlContext.implicits._
// Wrapping a string with the $ sign yields a Column; this is an implicit conversion defined in the SQLContext implicits object.
// The DataFrame API accepts Column objects. A UDF cannot be passed as a plain Scala function; wrap the function with the udf method from org.apache.spark.sql.functions.
// Used this way, no register call is needed.
// To pass a constant value into the function, use the lit function from org.apache.spark.sql.functions.
// Create a DataFrame
val df = sqlContext.createDataFrame(Seq(
(1, "張三峰", "廣東 廣州 天河", 24),
(2, "李四", "廣東 廣州 東山", 36),
(3, "王五", "廣東 廣州 越秀", 48),
(4, "趙六", "廣東 廣州 海珠", 29))).toDF("id", "name", "addr", "age")
// Define the functions
def splitAddrFunc: String => Seq[String] = {
_.toLowerCase.split("\\s")
}
val longLength = udf((str: String, length: Int) => str.length > length)
val len = udf((str: String) => str.length)
// Apply the functions
val df2 = df.withColumn("addr-ex", callUDF(splitAddrFunc, new ArrayType(StringType, true), df("addr")))
val df3 = df2.withColumn("name-len", len($"name")).filter(longLength($"name", lit(2)))
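// Note: with the sample data above, only "張三峰" (3 characters) satisfies length > 2, so df3 is expected to contain a single row.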
println("列印DF Schema及資料處理結果")
df.printSchema()
df3.printSchema()
df3.foreach { println }
// The SQL approach
// Define ordinary Scala functions and register them with the SQLContext; they can then be used in SQL.
def slen(str: String): Int = str.length
def slengthLongerThan(str: String, length: Int): Boolean = str.length > length
sqlContext.udf.register("len", slen _)
sqlContext.udf.register("longLength", slengthLongerThan _)
df.registerTempTable("user")
println("列印SQL語句執行結果")
sqlContext.sql("select name,len(name) from user where longLength(name,2)").foreach(println)
println("列印資料過濾結果")
df.filter("longLength(name,2)").foreach(println)
// Defining a UDAF (User Defined Aggregate Function)
// Spark provides a common parent class, UserDefinedAggregateFunction, for all UDAFs; extending it means implementing its abstract methods.
val salesDF = sqlContext.createDataFrame(Seq(
(1, "Widget Co", 1000.00, 0.00, "AZ", "2014-01-02"),
(2, "Acme Widgets", 2000.00, 500.00, "CA", "2014-02-01"),
(3, "Widgetry", 1000.00, 200.00, "CA", "2015-01-11"),
(4, "Widgets R Us", 5000.00, 0.0, "CA", "2015-02-19"),
(5, "Ye Olde Widgete", 4200.00, 0.0, "MA", "2015-02-18"))).toDF("id", "name", "sales", "discount", "state", "saleDate")
salesDF.registerTempTable("sales")
val current = DateRange(Timestamp.valueOf("2015-01-01 00:00:00"), Timestamp.valueOf("2015-12-31 00:00:00"))
// Apart from the need to instantiate the UDAF, using it is no different from using an ordinary UDF
val yearOnYear = new YearOnYearCompare(current)
sqlContext.udf.register("yearOnYear", yearOnYear)
val dataFrame = sqlContext.sql("select yearOnYear(sales, saleDate) as yearOnYear from sales")
salesDF.printSchema()
dataFrame.show()
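// With the sample data, the 2015 sales sum to 10200.0 and the 2014 sales sum to 3000.0, so the expected
// year-on-year value is (10200 - 3000) / 3000 * 100 = 240.0 (assuming the string saleDate column is cast
// to DateType as inputSchema requests; this is a derivation from the sample data, not a captured run).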
}
}
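Besides registering the UDAF for SQL as above, the same instance can also be invoked through the DataFrame API, because UserDefinedAggregateFunction exposes an apply(exprs: Column*) method that returns a Column. A minimal sketch, reusing the yearOnYear and salesDF values from the listing above:
// Aggregate over the whole DataFrame with the UDAF instance; equivalent to the SQL query above
val yoyDF = salesDF.agg(yearOnYear($"sales", $"saleDate").as("yearOnYear"))
yoyDF.show()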