Spark(十三)SparkSQL的自定義函式UDF與開窗函式
阿新 • • 發佈:2021-08-03
一 自定義函式UDF
在Spark中,也支援Hive中的自定義函式。自定義函式大致可以分為三種:
- UDF(User-Defined-Function),即最基本的自定義函式,類似to_char,to_date等
- UDAF(User- Defined Aggregation Funcation),使用者自定義聚合函式,類似在group by之後使用的sum,avg等
- UDTF(User-Defined Table-Generating Functions),使用者自定義生成函式,有點像stream裡面的flatMap
自定義一個UDF函式需要繼承UserDefinedAggregateFunction類,並實現其中的8個方法
示例
import org.apache.spark.sql.Row import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.types.{DataType, StringType, StructField, StructType} object GetDistinctCityUDF extends UserDefinedAggregateFunction{ /** * 輸入的資料型別 * */ override def inputSchema: StructType = StructType( StructField("status",StringType,true) :: Nil ) /** * 快取欄位型別 * */ override def bufferSchema: StructType = { StructType( Array( StructField("buffer_city_info",StringType,true) ) ) } /** * 輸出結果型別 * */ override def dataType: DataType = StringType /** * 輸入型別和輸出型別是否一致 * */ override def deterministic: Boolean = true /** * 對輔助欄位進行初始化 * */ override def initialize(buffer: MutableAggregationBuffer): Unit = { buffer.update(0,"") } /** *修改輔助欄位的值 * */ override def update(buffer: MutableAggregationBuffer, input: Row): Unit = { //獲取最後一次的值 var last_str = buffer.getString(0) //獲取當前的值 val current_str = input.getString(0) //判斷最後一次的值是否包含當前的值 if(!last_str.contains(current_str)){ //判斷是否是第一個值,是的話走if賦值,不是的話走else追加 if(last_str.equals("")){ last_str = current_str }else{ last_str += "," + current_str } } buffer.update(0,last_str) } /** *對分割槽結果進行合併 * buffer1是機器hadoop1上的結果 * buffer2是機器Hadoop2上的結果 * */ override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { var buf1 = buffer1.getString(0) val buf2 = buffer2.getString(0) //將buf2裡面存在的資料而buf1裡面沒有的資料追加到buf1 //buf2的資料按照,進行切分 for(s <- buf2.split(",")){ if(!buf1.contains(s)){ if(buf1.equals("")){ buf1 = s }else{ buf1 += s } } } buffer1.update(0,buf1) } /** * 最終的計算結果 * */ override def evaluate(buffer: Row): Any = { buffer.getString(0) } }
註冊自定義的UDF函式為臨時函式
def main(args: Array[String]): Unit = { /** * 第一步 建立程式入口 */ val conf = new SparkConf().setAppName("AralHotProductSpark") val sc = new SparkContext(conf) val hiveContext = new HiveContext(sc) //註冊成為臨時函式 hiveContext.udf.register("get_distinct_city",GetDistinctCityUDF) //註冊成為臨時函式 hiveContext.udf.register("get_product_status",(str:String) =>{ var status = 0 for(s <- str.split(",")){ if(s.contains("product_status")){ status = s.split(":")(1).toInt } } }) }
二開窗函式
row_number() 開窗函式是按照某個欄位分組,然後取另一欄位的前幾個的值,相當於分組取topN
如果SQL語句裡面使用到了開窗函式,那麼這個SQL語句必須使用HiveContext來執行,HiveContext預設情況下在本地無法建立。
開窗函式格式:
row_number() over (partitin by XXX order by XXX)
java:
SparkConf conf = new SparkConf(); conf.setAppName("windowfun"); JavaSparkContext sc = new JavaSparkContext(conf); HiveContext hiveContext = new HiveContext(sc); hiveContext.sql("use spark"); hiveContext.sql("drop table if exists sales"); hiveContext.sql("create table if not exists sales (riqi string,leibie string,jine Int) " + "row format delimited fields terminated by '\t'"); hiveContext.sql("load data local inpath '/root/test/sales' into table sales"); /** * 開窗函式格式: * 【 rou_number() over (partitin by XXX order by XXX) 】 */ DataFrame result = hiveContext.sql("select riqi,leibie,jine " + "from (" + "select riqi,leibie,jine," + "row_number() over (partition by leibie order by jine desc) rank " + "from sales) t " + "where t.rank<=3"); result.show(); sc.stop();
scala:
val conf = new SparkConf() conf.setAppName("windowfun") val sc = new SparkContext(conf) val hiveContext = new HiveContext(sc) hiveContext.sql("use spark"); hiveContext.sql("drop table if exists sales"); hiveContext.sql("create table if not exists sales (riqi string,leibie string,jine Int) " + "row format delimited fields terminated by '\t'"); hiveContext.sql("load data local inpath '/root/test/sales' into table sales"); /** * 開窗函式格式: * 【 rou_number() over (partitin by XXX order by XXX) 】 */ val result = hiveContext.sql("select riqi,leibie,jine " + "from (" + "select riqi,leibie,jine," + "row_number() over (partition by leibie order by jine desc) rank " + "from sales) t " + "where t.rank<=3"); result.show(); sc.stop()轉自:https://www.cnblogs.com/frankdeng/p/9301712.html