SparkSQL之自定義函式UDF和UDAF
阿新 • • 發佈:2018-11-11
SparkSQL中有兩種自定函式,在我們使用自帶的函式時無法滿足自己的需求時,可以使用自定義函式,SparkSQL中有兩種自定義函式,一種是UDF,另一種是UDAF,和Hive 很類似,但是hive中還有UDTF,一進多出,但是sparkSQL中沒有,這是因為spark中用 flatMap這個函式,可以實現和udtf相同的功能
UDF函式是針對的是一進一出
UDAF針對的是多進一出
udf很簡單,只需要註冊一下,然後寫一個函式,就可以在sql查詢中使用了
df1.createTempView("user")
//註冊
spark.udf.register ("lengthStr",(str:String)=>str.length)//自定義函式
//直接在sql中就可以使用啦
val df2 = spark.sql("select lengthStr(name) from user")
udaf相對來說比較複雜一點,需要繼承一個 UserDefinedAggregateFunction類,在重寫其中的方法,自定義函式求平均值,詳細的步驟在下面的程式碼中
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SparkSession, types}
object UDAFavg {
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder().appName("avg").master("local").getOrCreate()
val sc = spark.sparkContext
val sqlContext = spark.sqlContext
val files: RDD[String] = sc.textFile("D:\\read\\teacher.txt")
val rowRDD: RDD[Row] = files.map(row => {
val split = row.split(" ")
Row(split(0), split(1),split(2).toLong)
})
/* rowRDD.foreach(row =>{
println(row.getString(0)+" "+row.getString(1)+row.get(2))
})*/
val structType = StructType(List(StructField("subject",StringType,true),StructField("tname",StringType,true),
StructField("age",LongType,true)))
val df1: DataFrame = spark.createDataFrame(rowRDD,structType)
df1.createTempView("teacher")
//註冊函式, 自定義一個函式,實現求平均數
spark.udf.register("TeacherAvg",new UDAFavg)
//df1.show()
spark.sql("select subject,TeacherAvg(age) as avgAGE from teacher group by subject ").show()
}
}
//自定義UDAF函式
class UDAFavg extends UserDefinedAggregateFunction{
//輸入資料型別,求平均值,所以資料型別是LongType(StructType中的型別)
override def inputSchema: StructType = {
StructType(List(StructField("age",LongType,true)))}
//中間結果的型別,這裡定義了兩個中間的型別,因為在求平均值時,首先一個存總的和,一個計算個數,最後的結果是兩者相除
override def bufferSchema: StructType = {
StructType(List(StructField("age",LongType),StructField("count",LongType)))}
//輸出返回型別
override def dataType: DataType = {LongType}
//是否資料同一性,一般都是true
override def deterministic: Boolean = true
//初始化定義兩個中間值
override def initialize(buffer: MutableAggregationBuffer): Unit = {
//型別要和上面定義的位置相對應
buffer(0) = 0L //初始化 總和
buffer(1) = 0L // 個數
}
//進行計算,
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = { //input是每次的輸入Row型別
buffer(1) = buffer.getAs[Long](1)+ 1 //個數 每次加1
buffer(0) = buffer.getAs[Long](0) + input.getLong(0)
// 把每個傳的值進行累加
}
//有可能有多個分割槽,多個task ,總後把進行合併
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(1) = buffer1.getAs[Long](1)+ buffer2.getAs[Long](1)//多臺機器中的count的值進行相加
buffer1(0) = buffer1.getAs[Long](0) + buffer2.getLong(0)
}
//返回的最終結果
override def evaluate(buffer: Row): Any = {
buffer.getAs[Long](0) / buffer.getAs[Long](1)
}
}