Spark之hive的UDF自定義函式
阿新 • • 發佈:2019-01-29
1.簡單的
package com.llcc.sparkSql.MyTimeSort
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.hive.HiveContext
object UDFDemo1 {
  /**
   * Entry point: registers a "strlen" UDF on a HiveContext and runs a
   * query that applies it to the `category` column of `xtwy.worker`.
   */
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("aa")
    val sc = new SparkContext(conf)
    val hiveContext = new HiveContext(sc)
    // "strlen" returns the string's length, or 0 when the input is SQL NULL
    // (Option(null) maps to None, so fold supplies the 0 default).
    hiveContext.udf.register("strlen", (str: String) => Option(str).fold(0)(_.length))
    hiveContext.sql("select strlen(category) from xtwy.worker" ).show()
  }
}
2. 繼承 UserDefinedAggregateFunction
package com.llcc.sparkSql.MyTimeSort
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.types._
object UDFDemo extends UserDefinedAggregateFunction {
  /**
   * Input schema: a single nullable Double column — the `salary` field we
   * are averaging. `Nil` terminates the StructField list because StructType
   * expects a sequence even for one column.
   */
  override def inputSchema: StructType = StructType(
    StructField("salary", DoubleType, true) :: Nil
  )

  /**
   * Aggregation buffer layout: slot 0 = running total of salaries,
   * slot 1 = number of non-null rows folded in so far.
   */
  override def bufferSchema: StructType = StructType(
    StructField("total", DoubleType, true) ::
    StructField("count", IntegerType, true) :: Nil
  )

  /** Result type of the aggregate: the average salary as a Double. */
  override def dataType: DataType = DoubleType

  /** The same input rows always produce the same result. */
  override def deterministic: Boolean = true

  /**
   * Reset the buffer before aggregation: zero total, zero count.
   * @param buffer the mutable aggregation buffer to initialize
   */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, 0.0)
    buffer.update(1, 0)
  }

  /**
   * Fold one input row into the buffer.
   * NULL salaries are skipped: calling getDouble on a null cell would throw,
   * and SQL average semantics ignore NULL values anyway.
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      val total = buffer.getDouble(0)
      val count = buffer.getInt(1)
      val currentSalary = input.getDouble(0)
      buffer.update(0, total + currentSalary)
      buffer.update(1, count + 1)
    }
  }

  /**
   * Merge two partial buffers (combining per-partition results).
   * Totals and counts add component-wise.
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val total1 = buffer1.getDouble(0)
    val count1 = buffer1.getInt(1)
    val total2 = buffer2.getDouble(0)
    val count2 = buffer2.getInt(1)
    buffer1.update(0, total1 + total2)
    buffer1.update(1, count1 + count2)
  }

  /**
   * Final result: total / count.
   * Returns null (SQL NULL) when no non-null rows were aggregated,
   * instead of the NaN that 0.0 / 0 would produce.
   */
  override def evaluate(buffer: Row): Any = {
    val count = buffer.getInt(1)
    if (count == 0) null
    else buffer.getDouble(0) / count
  }

  /** Registers the aggregate as "salary_avg" and averages worker salaries. */
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("aa")
    val sc = new SparkContext(conf)
    val hiveContext = new HiveContext(sc)
    hiveContext.udf.register("salary_avg", UDFDemo)
    hiveContext.sql("select salary_avg(salary) from xtwy.worker" ).show()
  }
}
原始資料
求薪水的平均值,可以看到是正確的