
Spark Source Code Reading 3: The top Method of RDD

Environment for this Spark source-reading series: spark-2.0.1 (downloaded from GitHub on 2016-11-03)

1. Understanding

Top-k is a common need when reading results back from Spark, and RDD provides a top method for it. In particular, when an RDD is large, be careful with collect; prefer take or top. If you need the results in order, use top.

1.1 Definition

  def top(num: Int)(implicit ord: Ordering[T]): Array[T] = withScope {
    takeOrdered(num)(ord.reverse)
  }

num is the number of elements to take, and ord is an implicit Ordering. You can take the largest top-k, or pass in a reversed ordering to take the smallest top-k. top simply delegates to takeOrdered.
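The relationship between top and takeOrdered can be sketched in plain Java (a hypothetical single-machine analogue, not the distributed implementation): taking the top k under a comparator is the same as taking the least k under the reversed comparator.

```java
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;

class TopK {
    // Least k under cmp: sort a copy and keep the first k, smallest first.
    static <T> List<T> takeOrdered(List<T> xs, int k, Comparator<T> cmp) {
        return xs.stream().sorted(cmp).limit(k).collect(Collectors.toList());
    }

    // Top k = least k under the reversed ordering, mirroring RDD.top's
    // takeOrdered(num)(ord.reverse).
    static <T> List<T> top(List<T> xs, int k, Comparator<T> cmp) {
        return takeOrdered(xs, k, cmp.reversed());
    }
}
```

For example, top on the list 4,5,3,2,1,6,7,9,8,10 with k = 5 yields [10, 9, 8, 7, 6], matching the sample run in section 2 below.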

1.2 Source Code Walkthrough

1.2.1 takeOrdered

top calls takeOrdered, whose source is:

  /**
   * Returns the first k (smallest) elements from this RDD as defined by the specified
   * implicit Ordering[T] and maintains the ordering. This does the opposite of [[top]].
   * For example:
   * {{{
   *   sc.parallelize(Seq(10, 4, 2, 12, 3)).takeOrdered(1)
   *   // returns Array(2)
   *
   *   sc.parallelize(Seq(2, 3, 4, 5, 6)).takeOrdered(2)
   *   // returns Array(2, 3)
   * }}}
   *
   * @note this method should only be used if the resulting array is expected to be small, as
   * all the data is loaded into the driver's memory.
   *
   * @param num k, the number of elements to return
   * @param ord the implicit ordering for T
   * @return an array of top elements
   */

  def takeOrdered(num: Int)(implicit ord: Ordering[T]): Array[T] = withScope {
    if (num == 0) {
      Array.empty
    } else {
      val mapRDDs = mapPartitions { items =>
        // Priority keeps the largest elements, so let's reverse the ordering.
        val queue = new BoundedPriorityQueue[T](num)(ord.reverse)
        queue ++= util.collection.Utils.takeOrdered(items, num)(ord)
        Iterator.single(queue)
      }
      if (mapRDDs.partitions.length == 0) {
        Array.empty
      } else {
        mapRDDs.reduce { (queue1, queue2) =>
          queue1 ++= queue2
          queue1
        }.toArray.sorted(ord)
      }
    }
  }

Understanding:
1.2.1.1 takeOrdered uses a bounded priority queue, BoundedPriorityQueue, to hold the k elements to return.
1.2.1.2 mapPartitions operates on each partition: for each partition's element iterator items, it calls org.apache.spark.util.collection.Utils.takeOrdered to take num elements, producing mapRDDs, whose partitions each contain a single bounded priority queue of capacity k.
1.2.1.3 A reduce then merges the queues with ++=, combining two queues of up to k elements into one queue still capped at k, followed by toArray and sorted(ord). ++= is a method of BoundedPriorityQueue that calls += for each element:

  override def ++=(xs: TraversableOnce[A]): this.type = {
    xs.foreach { this += _ }
    this
  }

  override def +=(elem: A): this.type = {
    if (size < maxSize) {
      underlying.offer(elem)
    } else {
      maybeReplaceLowest(elem)
    }
    this
  }

See the other methods of BoundedPriorityQueue for details.
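The idea behind ++= and += above can be sketched in a minimal, single-threaded Java version (a hypothetical simplification, not the actual Spark class): keep a size-capped heap whose head is the current "lowest" element, and replace the head only when a better element arrives.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;

// Minimal sketch of a bounded priority queue in the spirit of Spark's
// BoundedPriorityQueue (names and structure are illustrative only).
class BoundedPQ<T> {
    private final int maxSize;
    private final Comparator<T> cmp;
    private final PriorityQueue<T> underlying;

    BoundedPQ(int maxSize, Comparator<T> cmp) {
        this.maxSize = maxSize;
        this.cmp = cmp;
        this.underlying = new PriorityQueue<>(maxSize, cmp);
    }

    // Mirrors +=: offer while under capacity; otherwise replace the head
    // (the "lowest" element under cmp) only if the new element ranks higher.
    void add(T elem) {
        if (underlying.size() < maxSize) {
            underlying.offer(elem);
        } else if (cmp.compare(elem, underlying.peek()) > 0) {
            underlying.poll();
            underlying.offer(elem);
        }
    }

    // Mirrors ++=: add every element of another collection.
    void addAll(Iterable<T> xs) {
        for (T x : xs) add(x);
    }

    List<T> toList() {
        return new ArrayList<>(underlying); // heap order, NOT sorted
    }
}
```

Feeding 1 through 10 into a capacity-3 queue with the natural ordering leaves exactly {8, 9, 10} inside, which is the per-partition behavior takeOrdered relies on.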

The sorted(ord) call ends up in java.util.Arrays.sort, covered in 1.2.3.1 below.

1.2.2 org.apache.spark.util.collection.Utils.takeOrdered

Utils.takeOrdered delegates to com.google.common.collect.{Ordering => GuavaOrdering}, overriding compare to use the implicit ord. Note that Ordering is ascending by default (smallest first), while top wants the largest num elements; that is why top passes ord.reverse into takeOrdered.

val ordering = new GuavaOrdering[T] {
  override def compare(l: T, r: T): Int = ord.compare(l, r)
}

It then calls the ordering's leastOf method, converting between Java and Scala iterators along the way:

ordering.leastOf(input.asJava, num).iterator.asScala    
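The semantics of leastOf can be reproduced with the JDK alone (a naive O(n log n) stand-in for illustration, not Guava's one-pass O(n) implementation):

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;

class Least {
    // Naive equivalent of Ordering.leastOf: the k least elements, ascending.
    // If fewer than k elements are present, all of them are returned.
    static <E> List<E> leastOf(Iterator<E> elements, int k, Comparator<E> cmp) {
        List<E> list = new ArrayList<>();
        elements.forEachRemaining(list::add);
        list.sort(cmp);
        return new ArrayList<>(list.subList(0, Math.min(k, list.size())));
    }
}
```

For the input 10, 4, 2, 12, 3 and k = 2 this returns [2, 3], matching the takeOrdered doc comment quoted above.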

1.2.3 The leastOf method of com.google.common.collect.Ordering

To pull in the dependency, add to the Maven pom file:

<dependency>
  <groupId>com.google.guava</groupId>
  <artifactId>guava</artifactId>
  <version>14.0.1</version>
  <scope>provided</scope>
</dependency>

Source:

 /**
   * Returns the {@code k} least elements from the given iterator according to
   * this ordering, in order from least to greatest.  If there are fewer than
   * {@code k} elements present, all will be included.
   *
   * <p>The implementation does not necessarily use a <i>stable</i> sorting
   * algorithm; when multiple elements are equivalent, it is undefined which
   * will come first.
   *
   * @return an immutable {@code RandomAccess} list of the {@code k} least
   *     elements in ascending order
   * @throws IllegalArgumentException if {@code k} is negative
   * @since 14.0
   */
  public <E extends T> List<E> leastOf(Iterator<E> elements, int k) {
    checkNotNull(elements);
    checkArgument(k >= 0, "k (%s) must be nonnegative", k);

    if (k == 0 || !elements.hasNext()) {
      return ImmutableList.of();
    } else if (k >= Integer.MAX_VALUE / 2) {
      // k is really large; just do a straightforward sorted-copy-and-sublist
      ArrayList<E> list = Lists.newArrayList(elements);
      Collections.sort(list, this);
      if (list.size() > k) {
        list.subList(k, list.size()).clear();
      }
      list.trimToSize();
      return Collections.unmodifiableList(list);
    }

    /*
     * Our goal is an O(n) algorithm using only one pass and O(k) additional
     * memory.
     *
     * We use the following algorithm: maintain a buffer of size 2*k. Every time
     * the buffer gets full, find the median and partition around it, keeping
     * only the lowest k elements.  This requires n/k find-median-and-partition
     * steps, each of which take O(k) time with a traditional quickselect.
     *
     * After sorting the output, the whole algorithm is O(n + k log k). It
     * degrades gracefully for worst-case input (descending order), performs
     * competitively or wins outright for randomly ordered input, and doesn't
     * require the whole collection to fit into memory.
     */
    int bufferCap = k * 2;
    @SuppressWarnings("unchecked") // we'll only put E's in
    E[] buffer = (E[]) new Object[bufferCap];
    E threshold = elements.next();
    buffer[0] = threshold;
    int bufferSize = 1;
    // threshold is the kth smallest element seen so far.  Once bufferSize >= k,
    // anything larger than threshold can be ignored immediately.

    while (bufferSize < k && elements.hasNext()) {
      E e = elements.next();
      buffer[bufferSize++] = e;
      threshold = max(threshold, e);
    }

    while (elements.hasNext()) {
      E e = elements.next();
      if (compare(e, threshold) >= 0) {
        continue;
      }

      buffer[bufferSize++] = e;
      if (bufferSize == bufferCap) {
        // We apply the quickselect algorithm to partition about the median,
        // and then ignore the last k elements.
        int left = 0;
        int right = bufferCap - 1;

        int minThresholdPosition = 0;
        // The leftmost position at which the greatest of the k lower elements
        // -- the new value of threshold -- might be found.

        while (left < right) {
          int pivotIndex = (left + right + 1) >>> 1;
          int pivotNewIndex = partition(buffer, left, right, pivotIndex);
          if (pivotNewIndex > k) {
            right = pivotNewIndex - 1;
          } else if (pivotNewIndex < k) {
            left = Math.max(pivotNewIndex, left + 1);
            minThresholdPosition = pivotNewIndex;
          } else {
            break;
          }
        }
        bufferSize = k;

        threshold = buffer[minThresholdPosition];
        for (int i = minThresholdPosition + 1; i < bufferSize; i++) {
          threshold = max(threshold, buffer[i]);
        }
      }
    }

    Arrays.sort(buffer, 0, bufferSize, this);

    bufferSize = Math.min(bufferSize, k);
    return Collections.unmodifiableList(
        Arrays.asList(ObjectArrays.arraysCopyOf(buffer, bufferSize)));
    // We can't use ImmutableList; we have to be null-friendly!
  }

  private <E extends T> int partition(
      E[] values, int left, int right, int pivotIndex) {
    E pivotValue = values[pivotIndex];

    values[pivotIndex] = values[right];
    values[right] = pivotValue;

    int storeIndex = left;
    for (int i = left; i < right; i++) {
      if (compare(values[i], pivotValue) < 0) {
        ObjectArrays.swap(values, storeIndex, i);
        storeIndex++;
      }
    }
    ObjectArrays.swap(values, right, storeIndex);
    return storeIndex;
  }

Source analysis

1.2.3.1 When k satisfies (k >= Integer.MAX_VALUE / 2), the "straightforward sorted-copy-and-sublist" approach is used: copy, sort, and take a sublist.
The sort is Collections.sort(list, this), which in turn calls Arrays.sort(a, (Comparator)c);

Arrays.sort source:

  public static <T> void sort(T[] a, Comparator<? super T> c) {
    if (LegacyMergeSort.userRequested)
      legacyMergeSort(a, c);
    else
      TimSort.sort(a, c);
  }

legacyMergeSort is a traditional merge sort: once a subrange shrinks below INSERTIONSORT_THRESHOLD (7 in the code), it switches to insertion sort; above the threshold it keeps splitting and merging. See java.util.Arrays#mergeSort(java.lang.Object[], java.lang.Object[], int, int, int); not covered in detail here.
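The cutoff idea can be sketched as follows (a simplified illustration only; the actual JDK legacyMergeSort recurses on a cloned source array rather than allocating per merge):

```java
import java.util.Comparator;

class LegacySort {
    private static final int INSERTIONSORT_THRESHOLD = 7; // same cutoff as the JDK

    @SuppressWarnings("unchecked")
    static <T> void mergeSort(T[] a, int lo, int hi, Comparator<? super T> c) {
        int len = hi - lo;
        // Small run: insertion sort is faster than recursing further.
        if (len < INSERTIONSORT_THRESHOLD) {
            for (int i = lo + 1; i < hi; i++)
                for (int j = i; j > lo && c.compare(a[j - 1], a[j]) > 0; j--) {
                    T t = a[j]; a[j] = a[j - 1]; a[j - 1] = t;
                }
            return;
        }
        // Larger run: sort both halves, then merge them.
        int mid = (lo + hi) >>> 1;
        mergeSort(a, lo, mid, c);
        mergeSort(a, mid, hi, c);
        Object[] tmp = new Object[len];
        int i = lo, j = mid, k = 0;
        while (i < mid && j < hi)
            tmp[k++] = c.compare(a[i], a[j]) <= 0 ? a[i++] : a[j++];
        while (i < mid) tmp[k++] = a[i++];
        while (j < hi) tmp[k++] = a[j++];
        for (k = 0; k < len; k++) a[lo + k] = (T) tmp[k];
    }
}
```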

Arrays.sort also provides TimSort:

 TimSort.sort(a, c);

which is Tim Peters's list sort for Python (TimSort).

1.2.3.2 When k < Integer.MAX_VALUE / 2, a buffer of size 2*k is created. While the buffer holds fewer than k elements and input remains, elements are inserted directly:

while (bufferSize < k && elements.hasNext()) {
  E e = elements.next();
  buffer[bufferSize++] = e;
  threshold = max(threshold, e);
}

threshold is updated via max, but max is not necessarily the numeric maximum: it goes through the overridden compare, so interpret it according to the ordering in use. Under top's default reversed ordering, max effectively selects the smallest value.

Once the buffer holds at least k elements, each new element is compared against threshold and inserted only if compare says it qualifies. When the buffer fills up to 2*k, the quickselect algorithm partitions it to keep the qualifying first k. Nothing is actually deleted: elements are moved so the qualifying k sit at the front of the buffer, and the back half may later be overwritten.

1.2.3.2.1 quickselect algorithm
The rough idea of quickselect: take the middle of the range as the pivot, traverse the buffer, and partition so that elements comparing lower than the pivot land before it and the rest after it; partition returns the pivot's final position storeIndex. That position is then compared with k: if storeIndex is greater than k, partitioning continues between left and storeIndex - 1; if storeIndex is less than k, it continues from storeIndex toward right; otherwise the first k slots are exactly what is wanted.
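The partition-around-k loop can be reproduced as a standalone sketch (same control flow as the Guava source quoted above, but omitting the minThresholdPosition bookkeeping):

```java
import java.util.Comparator;

class QuickSelectK {
    // Lomuto partition: move the pivot to its final sorted position and return it.
    static <E> int partition(E[] v, int left, int right, int pivotIndex,
                             Comparator<E> cmp) {
        E pivotValue = v[pivotIndex];
        v[pivotIndex] = v[right];
        v[right] = pivotValue;
        int storeIndex = left;
        for (int i = left; i < right; i++) {
            if (cmp.compare(v[i], pivotValue) < 0) {
                E t = v[storeIndex]; v[storeIndex] = v[i]; v[i] = t;
                storeIndex++;
            }
        }
        E t = v[right]; v[right] = v[storeIndex]; v[storeIndex] = t;
        return storeIndex;
    }

    // Rearranges v so positions 0..k-1 hold the k smallest elements (unsorted),
    // following the same narrowing logic as leastOf.
    static <E> void selectK(E[] v, int k, Comparator<E> cmp) {
        int left = 0;
        int right = v.length - 1;
        while (left < right) {
            int pivotIndex = (left + right + 1) >>> 1;
            int p = partition(v, left, right, pivotIndex, cmp);
            if (p > k) right = p - 1;
            else if (p < k) left = Math.max(p, left + 1);
            else break;
        }
    }
}
```

After selectK with k = 3 on {5, 1, 4, 2, 8, 3, 7, 6}, the first three slots contain 1, 2, 3 in some order; only the final Arrays.sort puts them in ascending order.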

1.2.3.3 Return

Finally Arrays.sort sorts the buffered elements, the first k are taken, and a copy is returned:

Arrays.sort(buffer, 0, bufferSize, this);

bufferSize = Math.min(bufferSize, k);
return Collections.unmodifiableList(
    Arrays.asList(ObjectArrays.arraysCopyOf(buffer, bufferSize)));

2. Code:

(1) Usage

Taking the largest top-k:

val nums = Seq(4, 5, 3, 2, 1, 6, 7, 9, 8, 10)
val ints = sc.makeRDD(scala.util.Random.shuffle(nums), 2)
val topK = ints.top(5)
topK.foreach(println)
assert(topK.size == 5)

Output:

10
9
8
7
6

Taking the smallest top-k:

val nums = Seq(4, 5, 3, 2, 1, 6, 7, 9, 8, 10)
implicit val ord = implicitly[Ordering[Int]].reverse
val ints = sc.makeRDD(scala.util.Random.shuffle(nums), 2)
val topK = ints.top(5)
topK.foreach(println)

Output:

1
2
3
4
5

Each step can be inspected in detail with a debugger.

3. Results:

The examples run successfully and the top method is basically understood; a few questions remain:

3.1 Why is the reduce result sorted after toArray? reduce returns a bounded BoundedPriorityQueue; would a cheaper reverse not suffice?
Likely explanation: the queue is backed by a binary heap, whose iteration order is not sorted (only the head is guaranteed to be the extreme element), so the final sort is in fact required.

Code: org.apache.spark.rdd.RDD#takeOrdered

  mapRDDs.reduce { (queue1, queue2) =>
      queue1 ++= queue2
      queue1
    }.toArray.sorted(ord)
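One observation relevant to this question: Spark's BoundedPriorityQueue wraps java.util.PriorityQueue, a binary heap, and a heap keeps only its head in order; iteration and toArray order are otherwise unspecified, so the sort cannot simply be replaced by a reverse. A small demonstration:

```java
import java.util.PriorityQueue;

class HeapOrderDemo {
    // Offers the values into a binary heap and returns its head.
    // Only peek() is guaranteed to be the least element under the ordering;
    // pq.toArray() and iteration are in no particular order, which is why
    // takeOrdered must still sort after toArray.
    static Integer headAfterOffers(int... xs) {
        PriorityQueue<Integer> pq = new PriorityQueue<>();
        for (int x : xs) pq.offer(x);
        return pq.peek();
    }
}
```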

3.2 Why does quickselect use a 2*k buffer instead of a bounded priority queue directly, which seems simpler?
Likely explanation: as the source comment notes, the buffered approach is a one-pass O(n + k log k) algorithm with O(k) extra memory, while a bounded heap pays O(log k) per accepted element, up to O(n log k) in the worst case; the buffer also degrades gracefully on adversarial input such as many equal values or descending order.

Code: com.google.common.collect.Ordering#leastOf(java.util.Iterator, int)

int bufferCap = k * 2;
@SuppressWarnings("unchecked") // we'll only put E's in
E[] buffer = (E[]) new Object[bufferCap];
E threshold = elements.next();
buffer[0] = threshold;
int bufferSize = 1;
