Spark Development: A Type-Safe UDAF Example in Spark

Developing a UDAF in Spark

 Working through the example code from the Spark sources hands-on is a good way to understand each part of the Aggregator API and to track down the errors that come up during development.
  System.out.println() can be used inside a UDAF to print intermediate state and verify behavior.
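
For instance, a temporary debug line in the reduce() method of the example below makes each buffer update visible on the console when running with a local master (a debugging sketch only, not part of the final code):

```java
@Override
public AverageBuffer reduce(AverageBuffer buffer, MyEmployee employee) {
    // Temporary debug output: shows how the buffer evolves per input row.
    System.out.println("reduce: sum=" + buffer.getSum()
            + ", count=" + buffer.getCount()
            + ", incoming salary=" + employee.getSalary());
    buffer.setSum(buffer.getSum() + employee.getSalary());
    buffer.setCount(buffer.getCount() + 1);
    return buffer;
}
```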

Example code

```java
import org.apache.spark.sql.*;
import org.apache.spark.sql.expressions.Aggregator;
import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
public class MeanTypeUDAF implements Serializable {
/**
 * IN: the input data type for the aggregation, modeled as a Java bean
 */
public static class MyEmployee implements Serializable {
    private String name;
    private long salary;
    /**
     * Encoders.bean requires a public no-argument constructor to
     * instantiate the bean; adding one fixes the instantiation error.
     * Also note the difference between primitive long and boxed Long.
     */
    public MyEmployee() {}

    private MyEmployee(String name, long salary) {
        this.name = name;
        this.salary = salary;
    }
    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    public long getSalary() {
        return salary;
    }

    public void setSalary(long salary) {
        this.salary = salary;
    }

}

/**
 * BUF: the buffer type that holds the intermediate aggregation state
 * (running sum and count)
 */
public static class AverageBuffer implements Serializable {
    private long sum;
    private long count;
    /**
     * Public no-argument constructor required by Encoders.bean;
     * again, mind primitive long vs boxed Long.
     */
    public AverageBuffer() {}
    private AverageBuffer(long sum, long count) {
        this.sum = sum;
        this.count = count;
    }

    public long getSum() {
        return sum;
    }
    public long getCount() {
        return count;
    }
    public void setSum(long sum) {
        this.sum = sum;
    }
    public void setCount(long count) {
        this.count = count;
    }
}

/**
 * abstract class Aggregator[-IN, BUF, OUT] extends Serializable
 *     IN:  input data type
 *     BUF: intermediate buffer type
 *     OUT: output data type
 */
public static class MyAverage extends Aggregator<MyEmployee, AverageBuffer, Double> {
    /**
     * Encoder for the intermediate buffer type: bufferEncoder: Encoder[BUF];
     * for a Java bean this is Encoders.bean(...)
     */
    @Override
    public Encoder<AverageBuffer> bufferEncoder() {
        return Encoders.bean(AverageBuffer.class);
    }

    /**
     * Encoder for the final output type: outputEncoder: Encoder[OUT];
     * Encoders.DOUBLE() matches the boxed Double result
     */
    @Override
    public Encoder<Double> outputEncoder() {
        return Encoders.DOUBLE();
    }

    /**
     * Initial ("zero") value for the aggregation; it must satisfy
     * b + zero = b for any buffer b.
     * zero: BUF
     */
    @Override
    public AverageBuffer zero() {
        return new AverageBuffer(0L, 0L);
    }

    /**
     * Folds one new input row into the buffer.
     * buffer holds the running aggregate; employee is the incoming row.
     * reduce(b: BUF, a: IN): BUF
     */
    @Override
    public AverageBuffer reduce(AverageBuffer buffer, MyEmployee employee) {
        long newSum = buffer.getSum() + employee.getSalary();
        long newCount = buffer.getCount() + 1;
        buffer.setSum(newSum);
        buffer.setCount(newCount);
        return buffer;
    }

    /**
     * Merges two buffers from different partitions (the global
     * aggregation step): merge(b1: BUF, b2: BUF): BUF
     */
    @Override
    public AverageBuffer merge(AverageBuffer b1, AverageBuffer b2) {
        long mergedSum = b1.getSum() + b2.getSum();
        long mergedCount = b1.getCount() + b2.getCount();
        b1.setSum(mergedSum);
        b1.setCount(mergedCount);
        return b1;
    }

    /**
     * Computes the final result from the reduced buffer:
     * finish(reduction: BUF): OUT
     */
    @Override
    public Double finish(AverageBuffer reduction) {
        return ((double) reduction.getSum()) / reduction.getCount();
    }
}
public static void main(String[] args) {
    SparkSession spark = SparkSession
            .builder()
            .appName("TypeSafeUDAFExample")
            .master("local[2]")
            .getOrCreate();
    // Create the input data as a list of bean instances
    List<MyEmployee> data = Arrays.asList(
            new MyEmployee("CFF", 30L),
            new MyEmployee("CFAF", 50L),
            new MyEmployee("ADD", 10L)
    );
    Encoder<MyEmployee> personEncoder = Encoders.bean(MyEmployee.class);
    Dataset<MyEmployee> itemsDataset = spark.createDataset(data, personEncoder);
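    // The bean encoder derives the schema from the getters; printSchema()
    // should show the two columns, name and salary.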
    itemsDataset.printSchema();
    itemsDataset.show();
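    // head() returns the first row deserialized back into a MyEmployee bean,
    // so its getters can be called directly.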
    System.out.println(itemsDataset.head().getName());
    System.out.println(itemsDataset.head().getSalary());
    MyAverage myAverage = new MyAverage();
    System.out.println("############");
    // Convert the Aggregator to a `TypedColumn` and give it a name,
    // so it can be used with the typed Dataset select().
    TypedColumn<MyEmployee, Double> averageSalary = myAverage.toColumn().name("average_salary");
    Dataset<Double> result = itemsDataset.select(averageSalary);
    result.show();
    spark.stop();
}
}
```
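
With the three sample rows (salaries 30, 50, and 10), the aggregation produces a single value, (30 + 50 + 10) / 3 = 30.0, which result.show() prints as a one-row Dataset<Double> under the column name average_salary.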

Notes

This is a type-safe UDAF example; along the way it also demonstrates, in a small way, using a Java bean as the data source for a Dataset.
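
The reference below also covers an untyped route: since Spark 3.0, an Aggregator can be wrapped with org.apache.spark.sql.functions.udaf and registered for use from SQL. The input type then has to be the column type rather than the whole bean, so a Long-input variant of the aggregator is needed. The following is a minimal sketch under that assumption; the class name MyAverageLong and the view name employees are illustrative, not part of the example above:

```java
import static org.apache.spark.sql.functions.udaf;

import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.expressions.Aggregator;
import org.apache.spark.sql.expressions.UserDefinedFunction;

// A Long-input variant of the average aggregator, reusing AverageBuffer;
// nest it inside MeanTypeUDAF next to MyAverage.
public static class MyAverageLong extends Aggregator<Long, AverageBuffer, Double> {
    @Override public AverageBuffer zero() {
        return new AverageBuffer();          // fields default to 0
    }
    @Override public AverageBuffer reduce(AverageBuffer b, Long salary) {
        b.setSum(b.getSum() + salary);
        b.setCount(b.getCount() + 1);
        return b;
    }
    @Override public AverageBuffer merge(AverageBuffer b1, AverageBuffer b2) {
        b1.setSum(b1.getSum() + b2.getSum());
        b1.setCount(b1.getCount() + b2.getCount());
        return b1;
    }
    @Override public Double finish(AverageBuffer r) {
        return ((double) r.getSum()) / r.getCount();
    }
    @Override public Encoder<AverageBuffer> bufferEncoder() {
        return Encoders.bean(AverageBuffer.class);
    }
    @Override public Encoder<Double> outputEncoder() {
        return Encoders.DOUBLE();
    }
}

// Registration and SQL usage, e.g. at the end of main():
UserDefinedFunction myAverage = udaf(new MyAverageLong(), Encoders.LONG());
spark.udf().register("my_average", myAverage);
itemsDataset.createOrReplaceTempView("employees");
spark.sql("SELECT my_average(salary) AS average_salary FROM employees").show();
```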

References

  http://spark.apache.org/docs/latest/sql-ref-functions-udf-aggregate.html