
Spark UDF Functions

package cn.com.systex

import java.sql.{Date, Timestamp}

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.functions.{callUDF, lit, udf}
import org.apache.spark.sql.types.{ArrayType, DateType, DoubleType, StringType, StructField, StructType}

/**
 * DateTime: 2015-12-25 10:41:42
 */
// A simple date-range class.
case class DateRange(startDate: Timestamp, endDate: Timestamp) {
  // True if targetDate falls strictly inside this range.
  def in(targetDate: Date): Boolean = {
    targetDate.before(endDate) && targetDate.after(startDate)
  }

  override def toString: String = {
    startDate.toLocaleString() + " " + endDate.toLocaleString()
  }
}

// A UDAF that aggregates by year and compares year-on-year. It must implement
// the abstract methods declared in UserDefinedAggregateFunction.
class YearOnYearCompare(current: DateRange) extends UserDefinedAggregateFunction {
  val previous: DateRange = DateRange(subtractOneYear(current.startDate), subtractOneYear(current.endDate))
  println(current)
  println(previous)

  // The input schema for the DataFrame columns the UDAF consumes. The StructField names carry
  // no special meaning; think of them as placeholder column names for the internal structure.
  // Which DataFrame columns the UDAF actually operates on is up to the caller, provided the
  // data types match what is declared here (DoubleType and DateType).
  def inputSchema: StructType = {
    StructType(StructField("metric", DoubleType) :: StructField("timeCategory", DateType) :: Nil)
  }

  // Schema of the buffer that holds intermediate results produced during aggregation.
  def bufferSchema: StructType = {
    StructType(StructField("sumOfCurrent", DoubleType) :: StructField("sumOfPrevious", DoubleType) :: Nil)
  }

  // The return type of the UDAF.
  def dataType: org.apache.spark.sql.types.DataType = DoubleType

  // Marks whether the UDAF always produces the same result for a given set of inputs.
  def deterministic: Boolean = true

  // Initialize the intermediate aggregation buffer.
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, 0.0)
    buffer.update(1, 0.0)
  }

  // The second parameter, input: Row, is not a DataFrame row but a row projected by
  // inputSchema. In this example, each input therefore carries exactly two field values.
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (current.in(input.getAs[Date](1))) {
      buffer(0) = buffer.getAs[Double](0) + input.getAs[Double](0)
    }
    if (previous.in(input.getAs[Date](1))) {
      buffer(1) = buffer.getAs[Double](1) + input.getAs[Double](0)
    }
  }

  // Merge two aggregation buffers and store the result back into the MutableAggregationBuffer.
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Double](0) + buffer2.getAs[Double](0)
    buffer1(1) = buffer1.getAs[Double](1) + buffer2.getAs[Double](1)
  }

  // Compute the final result from the aggregated buffer values.
  def evaluate(buffer: Row): Any = {
    if (buffer.getDouble(1) == 0.0) {
      0.0
    } else {
      (buffer.getDouble(0) - buffer.getDouble(1)) / buffer.getDouble(1) * 100
    }
  }

  // Uses the deprecated java.sql.Timestamp year accessors for brevity.
  private def subtractOneYear(date: Timestamp): Timestamp = {
    val prev = new Timestamp(date.getTime)
    prev.setYear(prev.getYear - 1)
    prev
  }
}
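Registering the UDAF with sqlContext.udf.register (as the demo below does) exposes it to SQL, but a UserDefinedAggregateFunction instance can also be applied directly to columns, since its apply method returns a Column. A minimal sketch of that DataFrame-API style, assuming the salesDF and current values constructed in SimpleDemo below:

// Hypothetical sketch, not part of the original demo: invoking the UDAF
// through the DataFrame API instead of SQL.
val compare = new YearOnYearCompare(current)
salesDF
  .groupBy(salesDF("state"))
  .agg(compare(salesDF("sales"), salesDF("saleDate")).as("yearOnYear"))
  .show()

The inputSchema contract still applies here: the first argument must be a DoubleType column and the second must be castable to DateType.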
/**
 * Spark 1.5 DataFrame API Highlights: Date/Time/String Handling, Time Intervals, and UDAFs
 * https://databricks.com/blog/2015/09/16/spark-1-5-dataframe-api-highlights-datetimestring-handling-time-intervals-and-udafs.html
 */
object SimpleDemo {
  def main(args: Array[String]): Unit = {
    val dir = "D:/Program/spark/examples/src/main/resources/"
    val sc = new SparkContext(new SparkConf().setMaster("local[4]").setAppName("sqltest"))
    val sqlContext = new org.apache.spark.sql.SQLContext(sc)
    import sqlContext.implicits._

    // Wrapping a string with $ yields a Column, via an implicit conversion defined in
    // SQLContext's implicits object.
    // The DataFrame API accepts Column objects, so a UDF cannot be passed in as a plain Scala
    // function; wrap it with the udf method from org.apache.spark.sql.functions.
    // Defined this way, the UDF needs no registration.
    // To pass a literal value into the function, use lit from org.apache.spark.sql.functions.

    // Create a DataFrame.
    val df = sqlContext.createDataFrame(Seq(
      (1, "張三峰", "廣東 廣州 天河", 24),
      (2, "李四", "廣東 廣州 東山", 36),
      (3, "王五", "廣東 廣州 越秀", 48),
      (4, "趙六", "廣東 廣州 海珠", 29))).toDF("id", "name", "addr", "age")

    // Define the functions.
    def splitAddrFunc: String => Seq[String] = {
      _.toLowerCase.split("\\s")
    }
    val longLength = udf((str: String, length: Int) => str.length > length)
    val len = udf((str: String) => str.length)

    // Use the functions.
    val df2 = df.withColumn("addr-ex", callUDF(splitAddrFunc, new ArrayType(StringType, true), df("addr")))
    val df3 = df2.withColumn("name-len", len($"name")).filter(longLength($"name", lit(2)))
    println("Print the DF schemas and the processed results")
    df.printSchema()
    df3.printSchema()
    df3.foreach { println }

    // SQL model: define an ordinary Scala function and register it with the SQLContext;
    // it can then be used in SQL statements.
    def slen(str: String): Int = str.length
    def slengthLongerThan(str: String, length: Int): Boolean = str.length > length
    sqlContext.udf.register("len", slen _)
    sqlContext.udf.register("longLength", slengthLongerThan _)
    df.registerTempTable("user")
    println("Print the SQL query results")
    sqlContext.sql("select name, len(name) from user where longLength(name, 2)").foreach(println)
    println("Print the filtered data")
    df.filter("longLength(name, 2)").foreach(println)

    // To define a UDAF (User Defined Aggregate Function):
    // Spark provides the parent class UserDefinedAggregateFunction for all UDAFs; subclass it
    // and implement its abstract methods, as YearOnYearCompare does above.
    val salesDF = sqlContext.createDataFrame(Seq(
      (1, "Widget Co", 1000.00, 0.00, "AZ", "2014-01-02"),
      (2, "Acme Widgets", 2000.00, 500.00, "CA", "2014-02-01"),
      (3, "Widgetry", 1000.00, 200.00, "CA", "2015-01-11"),
      (4, "Widgets R Us", 5000.00, 0.0, "CA", "2015-02-19"),
      (5, "Ye Olde Widgete", 4200.00, 0.0, "MA", "2015-02-18"))).toDF(
      "id", "name", "sales", "discount", "state", "saleDate")
    salesDF.registerTempTable("sales")
    val current = DateRange(Timestamp.valueOf("2015-01-01 00:00:00"), Timestamp.valueOf("2015-12-31 00:00:00"))

    // Apart from having to instantiate it, using a UDAF is no different from using an ordinary UDF.
    val yearOnYear = new YearOnYearCompare(current)
    sqlContext.udf.register("yearOnYear", yearOnYear)
    val dataFrame = sqlContext.sql("select yearOnYear(sales, saleDate) as yearOnYear from sales")
    salesDF.printSchema()
    dataFrame.show()
  }
}
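As a sanity check, the expected result for the sample data can be worked out by hand: the 2015 rows sum to 1000.00 + 5000.00 + 4200.00 = 10200.00 and the 2014 rows sum to 1000.00 + 2000.00 = 3000.00, so evaluate yields (10200 - 3000) / 3000 * 100 = 240.0. Because the query aggregates the whole table, dataFrame.show() should print a single row with yearOnYear = 240.0, assuming the saleDate strings cast cleanly to DateType.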