Spark Source Code Reading 3: The top Method of RDD
Series environment: spark-2.0.1 (downloaded from GitHub on 2016-11-03)
1. Understanding
A top-K algorithm is often needed when reading results back to the driver, and RDD provides a top method for this. In particular, when an RDD is large, be careful with collect; prefer take or top, which only bring back a bounded number of elements. If the result should be ordered, use top.
1.1 Definition
def top(num: Int)(implicit ord: Ordering[T]): Array[T] = withScope {
takeOrdered(num)(ord.reverse)
}
num is the number of elements to take, and ord is an implicit Ordering. With the default ordering, top returns the largest num elements; pass a reversed ordering to get the smallest. Internally, top delegates to takeOrdered.
1.2 Source code walkthrough
1.2.1 takeOrdered
The source of takeOrdered, which top delegates to:
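The symmetry between top and takeOrdered can be illustrated with plain Java streams (a minimal sketch, not Spark code): taking the k largest elements is the same as taking the k smallest under the reversed comparator.

```java
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;

public class TopViaReverse {
    // Take the k smallest elements of xs according to cmp, in cmp order.
    static List<Integer> takeOrdered(List<Integer> xs, int k, Comparator<Integer> cmp) {
        return xs.stream().sorted(cmp).limit(k).collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<Integer> xs = List.of(4, 5, 3, 2, 1, 6, 7, 9, 8, 10);
        Comparator<Integer> natural = Comparator.naturalOrder();
        // "top(k)" is just takeOrdered(k) under the reversed ordering:
        System.out.println(takeOrdered(xs, 3, natural.reversed())); // [10, 9, 8]
        System.out.println(takeOrdered(xs, 3, natural));            // [1, 2, 3]
    }
}
```

This mirrors the one-liner in top: reverse the ordering, then delegate.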
/**
 * Returns the first k (smallest) elements from this RDD as defined by the specified
 * implicit Ordering[T] and maintains the ordering. This does the opposite of [[top]].
 * For example:
 * {{{
 *   sc.parallelize(Seq(10, 4, 2, 12, 3)).takeOrdered(1)
 *   // returns Array(2)
 *
 *   sc.parallelize(Seq(2, 3, 4, 5, 6)).takeOrdered(2)
 *   // returns Array(2, 3)
 * }}}
 *
 * @note this method should only be used if the resulting array is expected to be small, as
 * all the data is loaded into the driver's memory.
 *
 * @param num k, the number of elements to return
 * @param ord the implicit ordering for T
 * @return an array of top elements
 */
def takeOrdered(num: Int)(implicit ord: Ordering[T]): Array[T] = withScope {
  if (num == 0) {
    Array.empty
  } else {
    val mapRDDs = mapPartitions { items =>
      // Priority keeps the largest elements, so let's reverse the ordering.
      val queue = new BoundedPriorityQueue[T](num)(ord.reverse)
      queue ++= util.collection.Utils.takeOrdered(items, num)(ord)
      Iterator.single(queue)
    }
    if (mapRDDs.partitions.length == 0) {
      Array.empty
    } else {
      mapRDDs.reduce { (queue1, queue2) =>
        queue1 ++= queue2
        queue1
      }.toArray.sorted(ord)
    }
  }
}
Notes:
1.2.1.1 takeOrdered uses a bounded priority queue (BoundedPriorityQueue) to hold the k elements to be returned.
1.2.1.2 mapPartitions operates on each partition independently: for each partition's element iterator items, it calls org.apache.spark.util.collection.Utils.takeOrdered to take num elements, producing mapRDDs, where each partition contains a single bounded priority queue of capacity num.
1.2.1.3 A reduce then merges the queues pairwise with ++=: two queues of capacity k, queue1 and queue2, are merged into one queue still bounded at k. The result is converted with toArray and sorted with sorted(ord). ++= is a method of BoundedPriorityQueue and delegates to +=:
override def ++=(xs: TraversableOnce[A]): this.type = {
xs.foreach { this += _ }
this
}
override def +=(elem: A): this.type = {
if (size < maxSize) {
underlying.offer(elem)
} else {
maybeReplaceLowest(elem)
}
this
}
See the BoundedPriorityQueue class for the full details.
The sorted(ord) call ultimately goes through java.util.Arrays.sort, discussed in section 1.2.3.1 below.
1.2.2 org.apache.spark.util.collection.Utils.takeOrdered
Utils.takeOrdered wraps Guava's com.google.common.collect.Ordering (imported as GuavaOrdering) and overrides its compare method: Guava's Ordering defaults to ascending order, whereas top needs the largest num elements.
val ordering = new GuavaOrdering[T] {
override def compare(l: T, r: T): Int = ord.compare(l, r)
}
It then calls the ordering's leastOf method, converting between Java and Scala iterator containers along the way:
ordering.leastOf(input.asJava, num).iterator.asScala
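Without pulling in Guava, the effect of Utils.takeOrdered (take the num smallest elements, in order) can be sketched with a stdlib max-heap of size num. This is an illustration of the contract, not Guava's buffer-based algorithm:

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;

public class LeastOfSketch {
    // Return the k least elements of it according to cmp, in ascending cmp order.
    static <T> List<T> leastOf(Iterator<T> it, int k, Comparator<T> cmp) {
        // Max-heap under cmp: the head is the largest of the k elements kept so far.
        PriorityQueue<T> heap = new PriorityQueue<>(Math.max(1, k), cmp.reversed());
        while (it.hasNext()) {
            T e = it.next();
            if (heap.size() < k) {
                heap.offer(e);
            } else if (k > 0 && cmp.compare(e, heap.peek()) < 0) {
                // e beats the worst kept element: swap it in
                heap.poll();
                heap.offer(e);
            }
        }
        List<T> out = new ArrayList<>(heap);
        out.sort(cmp);
        return out;
    }
}
```

This matches the examples in the takeOrdered doc comment: leastOf over Seq(10, 4, 2, 12, 3) with k = 1 yields [2].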
1.2.3 Guava's com.google.common.collect.Ordering#leastOf
To pull in the package, add the following to your Maven pom:
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>14.0.1</version>
<scope>provided</scope>
</dependency>
Source:
/**
* Returns the {@code k} least elements from the given iterator according to
* this ordering, in order from least to greatest. If there are fewer than
* {@code k} elements present, all will be included.
*
* <p>The implementation does not necessarily use a <i>stable</i> sorting
* algorithm; when multiple elements are equivalent, it is undefined which
* will come first.
*
* @return an immutable {@code RandomAccess} list of the {@code k} least
* elements in ascending order
* @throws IllegalArgumentException if {@code k} is negative
* @since 14.0
*/
public <E extends T> List<E> leastOf(Iterator<E> elements, int k) {
checkNotNull(elements);
checkArgument(k >= 0, "k (%s) must be nonnegative", k);
if (k == 0 || !elements.hasNext()) {
return ImmutableList.of();
} else if (k >= Integer.MAX_VALUE / 2) {
// k is really large; just do a straightforward sorted-copy-and-sublist
ArrayList<E> list = Lists.newArrayList(elements);
Collections.sort(list, this);
if (list.size() > k) {
list.subList(k, list.size()).clear();
}
list.trimToSize();
return Collections.unmodifiableList(list);
}
/*
* Our goal is an O(n) algorithm using only one pass and O(k) additional
* memory.
*
* We use the following algorithm: maintain a buffer of size 2*k. Every time
* the buffer gets full, find the median and partition around it, keeping
* only the lowest k elements. This requires n/k find-median-and-partition
* steps, each of which take O(k) time with a traditional quickselect.
*
* After sorting the output, the whole algorithm is O(n + k log k). It
* degrades gracefully for worst-case input (descending order), performs
* competitively or wins outright for randomly ordered input, and doesn't
* require the whole collection to fit into memory.
*/
int bufferCap = k * 2;
@SuppressWarnings("unchecked") // we'll only put E's in
E[] buffer = (E[]) new Object[bufferCap];
E threshold = elements.next();
buffer[0] = threshold;
int bufferSize = 1;
// threshold is the kth smallest element seen so far. Once bufferSize >= k,
// anything larger than threshold can be ignored immediately.
while (bufferSize < k && elements.hasNext()) {
E e = elements.next();
buffer[bufferSize++] = e;
threshold = max(threshold, e);
}
while (elements.hasNext()) {
E e = elements.next();
if (compare(e, threshold) >= 0) {
continue;
}
buffer[bufferSize++] = e;
if (bufferSize == bufferCap) {
// We apply the quickselect algorithm to partition about the median,
// and then ignore the last k elements.
int left = 0;
int right = bufferCap - 1;
int minThresholdPosition = 0;
// The leftmost position at which the greatest of the k lower elements
// -- the new value of threshold -- might be found.
while (left < right) {
int pivotIndex = (left + right + 1) >>> 1;
int pivotNewIndex = partition(buffer, left, right, pivotIndex);
if (pivotNewIndex > k) {
right = pivotNewIndex - 1;
} else if (pivotNewIndex < k) {
left = Math.max(pivotNewIndex, left + 1);
minThresholdPosition = pivotNewIndex;
} else {
break;
}
}
bufferSize = k;
threshold = buffer[minThresholdPosition];
for (int i = minThresholdPosition + 1; i < bufferSize; i++) {
threshold = max(threshold, buffer[i]);
}
}
}
Arrays.sort(buffer, 0, bufferSize, this);
bufferSize = Math.min(bufferSize, k);
return Collections.unmodifiableList(
Arrays.asList(ObjectArrays.arraysCopyOf(buffer, bufferSize)));
// We can't use ImmutableList; we have to be null-friendly!
}
private <E extends T> int partition(
E[] values, int left, int right, int pivotIndex) {
E pivotValue = values[pivotIndex];
values[pivotIndex] = values[right];
values[right] = pivotValue;
int storeIndex = left;
for (int i = left; i < right; i++) {
if (compare(values[i], pivotValue) < 0) {
ObjectArrays.swap(values, storeIndex, i);
storeIndex++;
}
}
ObjectArrays.swap(values, right, storeIndex);
return storeIndex;
}
Source analysis
1.2.3.1 When k >= Integer.MAX_VALUE / 2, leastOf falls back to a "straightforward sorted-copy-and-sublist": copy the whole input, sort it, and take the first k elements.
The sort is Collections.sort(list, this), which in turn calls Arrays.sort(a, (Comparator) c).
Arrays.sort source:
public static <T> void sort(T[] a, Comparator<? super T> c) {
if (LegacyMergeSort.userRequested)
legacyMergeSort(a, c);
else
TimSort.sort(a, c);
}
legacyMergeSort is the traditional merge sort: once a sub-range shrinks below INSERTIONSORT_THRESHOLD (7 in the code) it switches to insertion sort, otherwise it keeps recursing with merge sort; see java.util.Arrays#mergeSort(java.lang.Object[], java.lang.Object[], int, int, int). Not covered in detail here.
Otherwise, Arrays.sort uses TimSort (TimSort.sort(a, c)), an adaptation of Tim Peters's list sort for Python.
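A quick illustration of Arrays.sort with a Comparator (which dispatches to TimSort unless the legacy merge sort is requested via a system property):

```java
import java.util.Arrays;
import java.util.Comparator;

public class ComparatorSortDemo {
    public static void main(String[] args) {
        Integer[] a = {5, 1, 4, 2, 3};
        // Sort descending via a reversed natural-order comparator. TimSort is
        // stable, so equal elements keep their relative input order.
        Arrays.sort(a, Comparator.<Integer>reverseOrder());
        System.out.println(Arrays.toString(a)); // [5, 4, 3, 2, 1]
    }
}
```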
1.2.3.2 When k < Integer.MAX_VALUE / 2, leastOf allocates a buffer of size 2*k. While the buffer holds fewer than k elements and input remains, elements are inserted directly:
while (bufferSize < k && elements.hasNext()) {
E e = elements.next();
buffer[bufferSize++] = e;
threshold = max(threshold, e);
}
threshold is updated with max, but max is not necessarily the numerically largest value: it goes through the overridden compare, so it must be read relative to the ordering in effect. Since top reverses the ordering, its "max" is actually the numerically smallest element.
Once the buffer holds k elements, each new element is first compared against threshold and inserted only if it compares below it. When the buffer fills up to 2*k, the quickselect algorithm is applied to keep the k elements that belong in the answer. Nothing is actually deleted: the qualifying elements are moved into the first k slots of the buffer, and the upper k slots may simply be overwritten later.
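A tiny sketch of that point about max (my own helper, standing in for Guava's Ordering-aware max): under the reversed natural ordering, the "maximum" it picks is the numerically smallest value.

```java
import java.util.Comparator;

public class MaxUnderOrdering {
    // Pick the comparator-larger of two elements, as an ordering-aware max does.
    static <T> T max(T a, T b, Comparator<T> cmp) {
        return cmp.compare(a, b) >= 0 ? a : b;
    }

    public static void main(String[] args) {
        Comparator<Integer> natural = Comparator.naturalOrder();
        System.out.println(max(3, 7, natural));            // 7
        // With the reversed comparator (as top uses), "max" is the smallest number.
        System.out.println(max(3, 7, natural.reversed())); // 3
    }
}
```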
1.2.3.2.1 The quickselect algorithm
The rough idea: pick a pivot (here the middle index of the range), then partition the buffer so that elements comparing below the pivot end up before it and the rest after it; partition returns the pivot's final position storeIndex. That position is then compared with k: if it is greater than k, partitioning continues on [left, storeIndex - 1]; if it is less than k, on [storeIndex, right]; otherwise the first k slots already hold exactly the k least elements.
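The partition step above can be sketched as stand-alone Java, using the same Lomuto-style scheme as the Guava code (swap here is a hypothetical helper replacing ObjectArrays.swap):

```java
import java.util.Comparator;

public class PartitionSketch {
    static <T> void swap(T[] a, int i, int j) {
        T t = a[i]; a[i] = a[j]; a[j] = t;
    }

    // Lomuto-style partition: move the pivot to the end, sweep left-to-right,
    // collecting elements smaller than the pivot at the front, then place the
    // pivot at its final position storeIndex and return that index.
    static <T> int partition(T[] values, int left, int right, int pivotIndex,
                             Comparator<T> cmp) {
        T pivotValue = values[pivotIndex];
        swap(values, pivotIndex, right);
        int storeIndex = left;
        for (int i = left; i < right; i++) {
            if (cmp.compare(values[i], pivotValue) < 0) {
                swap(values, storeIndex, i);
                storeIndex++;
            }
        }
        swap(values, right, storeIndex);
        return storeIndex;
    }
}
```

After partition returns, every element before storeIndex compares below the pivot, which is exactly the invariant leastOf relies on to discard the upper half of the buffer.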
1.2.3.3 Return
Finally, Arrays.sort sorts the buffered elements, the first k are kept, and a copy is returned:
Arrays.sort(buffer, 0, bufferSize, this);
bufferSize = Math.min(bufferSize, k);
return Collections.unmodifiableList(
Arrays.asList(ObjectArrays.arraysCopyOf(buffer, bufferSize)));
2. Code
(1) Usage
Taking the largest top-K:
val nums = Array(4,5,3,2,1,6,7,9,8,10)
val ints = sc.makeRDD(scala.util.Random.shuffle(nums), 2)
val topK = ints.top(5)
topK.foreach(println)
assert(topK.size === 5)
Output:
10
9
8
7
6
Taking the smallest top-K:
val nums = Array(4,5,3,2,1,6,7,9,8,10)
implicit val ord = implicitly[Ordering[Int]].reverse
val ints = sc.makeRDD(scala.util.Random.shuffle(nums), 2)
val topK = ints.top(5)
topK.foreach(println)
Output:
1
2
3
4
5
You can step through each detail in a debugger.
3. Conclusions:
The examples run successfully and the top method is largely understood, but a few questions remain:
3.1 Why is the reduce result sorted after toArray? reduce returns a bounded BoundedPriorityQueue that is already ordered, so wouldn't a cheaper reverse suffice?
Possible reason: to guarantee a stable, fully ordered result? Note that a binary heap is only partially ordered, so PriorityQueue's toArray does not guarantee sorted order; an explicit sort may be unavoidable.
Code: org.apache.spark.rdd.RDD#takeOrdered
mapRDDs.reduce { (queue1, queue2) =>
queue1 ++= queue2
queue1
}.toArray.sorted(ord)
3.2 Why does the quickselect path use a 2*k buffer rather than a bounded priority queue, which seems both simpler and cheaper?
Possible reasons: avoiding pathological cases (many duplicate values), or the cost of maintaining a bounded heap when k is large?
Code: com.google.common.collect.Ordering#leastOf(java.util.Iterator, int)
int bufferCap = k * 2;
@SuppressWarnings("unchecked") // we'll only put E's in
E[] buffer = (E[]) new Object[bufferCap];
E threshold = elements.next();
buffer[0] = threshold;
int bufferSize = 1;