1. 程式人生 > 其它 >大資料演算法設計模式(1) - topN spark實現

大資料演算法設計模式(1) - topN spark實現

topN演算法,spark實現

package com.kangaroo.studio.algorithms.topn;


import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.broadcast.Broadcast;
import scala.Tuple2;

import java.io.Serializable;
import java.util.*;


public class TopNSpark implements Serializable {

    private JavaSparkContext jsc;
    Broadcast<Integer> topNum;
    private String inputPath;

    /*
    *   建構函式
    *   1. 初始化JavaSparkContext
    *   2. 初始化廣播變數topN個數, 可以被所有partition共享
    *   3. 初始化輸入路徑
    * */
    public TopNSpark(Integer Num, String path) {
        jsc = new JavaSparkContext();
        topNum = jsc.broadcast(Num);
        inputPath = path;
    }

    /*
    *   程式入口函式
    * */
    public void run() {
        /*
        *   讀入inputPath中的資料
        * */
        JavaRDD<String> lines = jsc.textFile(inputPath, 1);

        /*
        *   將rdd規約到9個分割槽
        * */
        JavaRDD<String> rdd = lines.coalesce(9);

        /*
        *   將輸入轉化為kv格式
        *   key是規約的主鍵, value是排序參考的個數
        *   注: 這裡的key並不唯一, 即相同的key可能有多條記錄, 所以下面我們規約key成唯一鍵
        *   輸入:line, 輸出:kv
        * */
        JavaPairRDD<String, Integer> kv = rdd.mapToPair(new PairFunction<String, String, Integer>() {
            public Tuple2<String, Integer> call(String s) throws Exception {
                String[] tokens = s.split(",");
                return new Tuple2<String, Integer>(tokens[0], Integer.parseInt(tokens[1]));
            }
        });

        /*
        *   規約主鍵成為唯一鍵
        *   輸入:kv, 輸出:kv
        * */
        JavaPairRDD<String, Integer> uniqueKeys = kv.reduceByKey(new Function2<Integer, Integer, Integer>() {
            public Integer call(Integer i1, Integer i2) throws Exception {
                return i1 + i2;
            }
        });

        /*
        *   計算各個分割槽的topN
        *   這裡通過廣播變數拿到了topN具體個數, 每個分割槽都保留topN, 所有分割槽總個數: partitionNum * topN
        *   輸入:kv, 輸出:SortMap, 長度topN
        * */
        JavaRDD<SortedMap<Integer, String>> partitions = uniqueKeys.mapPartitions(new FlatMapFunction<Iterator<Tuple2<String,Integer>>, SortedMap<Integer, String>>() {
            public Iterable<SortedMap<Integer, String>> call(Iterator<Tuple2<String, Integer>> iter) throws Exception {
                final int N = topNum.getValue();
                SortedMap<Integer, String> topN = new TreeMap<Integer, String>();
                while (iter.hasNext()) {
                    Tuple2<String, Integer> tuple = iter.next();
                    topN.put(tuple._2, tuple._1);

                    if (topN.size() > N) {
                        topN.remove(topN.firstKey());
                    }
                }
                return Collections.singletonList(topN);
            }
        });

        /*
        *   規約所有分割槽的topN SortMap, 得到最終的SortMap, 長度topN
        *   reduce過後, 資料已經到了本地快取, 這是最後結果
        *   輸入: SortMap, 長度topN, 當然有partitionNum個, 輸出:SortMap, 長度topN
        * */
        SortedMap<Integer, String> finalTopN = partitions.reduce(new Function2<SortedMap<Integer, String>, SortedMap<Integer, String>, SortedMap<Integer, String>>() {
            public SortedMap<Integer, String> call(SortedMap<Integer, String> m1, SortedMap<Integer, String> m2) throws Exception {
                final int N = topNum.getValue();
                SortedMap<Integer, String> topN = new TreeMap<Integer, String>();
                for (Map.Entry<Integer, String> entry : m1.entrySet()) {
                    topN.put(entry.getKey(), entry.getValue());
                    if (topN.size() > N) {
                        topN.remove(topN.firstKey());
                    }
                }
                for (Map.Entry<Integer, String> entry : m2.entrySet()) {
                    topN.put(entry.getKey(), entry.getValue());
                    if (topN.size() > N) {
                        topN.remove(topN.firstKey());
                    }
                }
                return topN;
            }
        });

        /*
        *   將本地快取的最終結果打印出來
        * */
        for (Map.Entry<Integer, String> entry : finalTopN.entrySet()) {
            System.out.println(entry.getKey() + " -- " + entry.getValue());
        }
    }


    public static void main(String[] args) {
        /*
        *   topN個數:topN
        *   輸入資料路徑:inputPath
        * */
        Integer topN = Integer.parseInt(args[0]);
        String inputPath = args[1];
        TopNSpark topNSpark = new TopNSpark(topN, inputPath);
        topNSpark.run();
    }
}