1. 程式人生 > >spark原始碼閱讀--shuffle讀過程原始碼分析

spark原始碼閱讀--shuffle讀過程原始碼分析

shuffle讀過程原始碼分析

上一篇中,我們分析了shuffle在map階段的寫過程。簡單回顧一下,主要是將ShuffleMapTask計算的結果資料在記憶體中按照分割槽和key進行排序,過程中由於記憶體限制會溢寫出多個磁碟檔案,最後會對所有的檔案和記憶體中剩餘的資料進行歸併排序並溢寫到一個檔案中,同時會記錄每個分割槽(reduce端分割槽)的資料在檔案中的偏移,並且把分割槽和偏移的對映關係寫到一個索引檔案中。
好了,簡單回顧了寫過程後,我們不禁思考,reduce階段的資料讀取的具體過程是什麼樣的?資料讀取的發生的時機是什麼?

首先應該回答後一個問題:資料讀取發生的時機是什麼?我們知道,rdd的計算鏈根據shuffle被切分為不同的stage,一個stage的開始階段一般就是從讀取上一階段的資料開始,也就是說stage讀取資料的過程其實就是reduce過程,然後經過該stage的計算鏈後得到結果資料,再然後就會把這些資料寫入到磁碟供下一個stage讀取,這個寫入的過程實際上就是map輸出過程,而這個過程我們之前已經分析過了。本篇我們要分析的是reduce階段讀取資料的過程。

囉嗦了這麼一大段,其實就是為了引出資料讀取的入口,還是要回到ShuffleMapTask,這裡我只貼部分程式碼:

  // shuffle管理器
  val manager = SparkEnv.get.shuffleManager
  // 獲取一個shuffle寫入器
  writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
  // 這裡可以看到rdd計算的核心方法就是iterator方法
  // SortShuffleWriter的write方法可以分為幾個步驟:
  // 將上游rdd計算出的資料(通過呼叫rdd.iterator方法)寫入記憶體緩衝區,
  // 在寫的過程中如果超過 記憶體閾值就會溢寫磁碟檔案,可能會寫多個檔案
  // 最後將溢寫的檔案和記憶體中剩餘的資料一起進行歸併排序後寫入到磁碟中形成一個大的資料檔案
  // 這個排序是先按分割槽排序,在按key排序
  // 在最後歸併排序後寫的過程中,沒寫一個分割槽就會手動刷寫一遍,並記錄下這個分割槽資料在檔案中的位移
  // 所以實際上最後寫完一個task的資料後,磁碟上會有兩個檔案:資料檔案和記錄每個reduce端partition資料位移的索引檔案
  writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
  // 主要是刪除中間過程的溢寫檔案,向記憶體管理器釋放申請的記憶體
  writer.stop(success = true).get

讀取資料的程式碼其實就是rdd.iterator(partition, context),
iterator方法主要是處理rdd快取的邏輯,如果有快取就會從快取中讀取(通過BlockManager),如果沒有快取就會進行實際的計算,發現最終呼叫RDD.compute方法進行實際的計算,這個方法是一個抽象方法,是由子類實現的具體的計算邏輯,使用者程式碼中對於RDD做的一些變換操作實際上最終都會體現在compute方法中。
另一方面,我們知道,map,filter這類運算元不是shuffle操作,不會導致stage的劃分,所以我們想看shuffle讀過程就要找一個Shuffle型別的操作,我們看一下RDD.groupBy,最終呼叫了groupByKey方法

RDD.groupByKey

def groupByKey(partitioner: Partitioner): RDD[(K, Iterable[V])] = self.withScope {
// groupByKey shouldn't use map side combine because map side combine does not
// reduce the amount of data shuffled and requires all map side data be inserted
// into a hash table, leading to more objects in the old gen.
val createCombiner = (v: V) => CompactBuffer(v)
val mergeValue = (buf: CompactBuffer[V], v: V) => buf += v
val mergeCombiners = (c1: CompactBuffer[V], c2: CompactBuffer[V]) => c1 ++= c2
val bufs = combineByKeyWithClassTag[CompactBuffer[V]](
  createCombiner, mergeValue, mergeCombiners, partitioner, mapSideCombine = false)
bufs.asInstanceOf[RDD[(K, Iterable[V])]]
}

最終呼叫了combineByKeyWithClassTag

RDD.combineByKeyWithClassTag

做一些判斷,檢查一些非法情況,然後處理一下分割槽器,最後返回一個ShuffledRDD,所以接下來我們分析一下ShuffleRDD的compute方法

def combineByKeyWithClassTag[C](
  createCombiner: V => C,
  mergeValue: (C, V) => C,
  mergeCombiners: (C, C) => C,
  partitioner: Partitioner,
  mapSideCombine: Boolean = true,
  serializer: Serializer = null)(implicit ct: ClassTag[C]): RDD[(K, C)] = self.withScope {
require(mergeCombiners != null, "mergeCombiners must be defined") // required as of Spark 0.9.0
// 如果key是Array型別,是不支援在map端合併的
// 並且也不支援HashPartitioner
if (keyClass.isArray) {
  if (mapSideCombine) {
    throw new SparkException("Cannot use map-side combining with array keys.")
  }
  if (partitioner.isInstanceOf[HashPartitioner]) {
    throw new SparkException("HashPartitioner cannot partition array keys.")
  }
}
// 聚合器,用於對資料進行聚合
val aggregator = new Aggregator[K, V, C](
  self.context.clean(createCombiner),
  self.context.clean(mergeValue),
  self.context.clean(mergeCombiners))
// 如果分割槽器相同,就不需要shuffle了
if (self.partitioner == Some(partitioner)) {
  self.mapPartitions(iter => {
    val context = TaskContext.get()
    new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context))
  }, preservesPartitioning = true)
} else {
  // 返回一個ShuffledRDD
  new ShuffledRDD[K, V, C](self, partitioner)
    .setSerializer(serializer)
    .setAggregator(aggregator)
    .setMapSideCombine(mapSideCombine)
}
}

ShuffleRDD.compute

通過shuffleManager獲取一個讀取器,資料讀取的邏輯在讀取器裡。

override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
// 通過shuffleManager獲取一個讀取器
SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
  .read()
  .asInstanceOf[Iterator[(K, C)]]
}

SortShuffleManager.getReader

無需多說,直接看BlockStoreShuffleReader

override def getReader[K, C](
  handle: ShuffleHandle,
  startPartition: Int,
  endPartition: Int,
  context: TaskContext): ShuffleReader[K, C] = {
new BlockStoreShuffleReader(
  handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
}

BlockStoreShuffleReader.read

顯然,這個方法才是核心所在。總結一下主要步驟:

  • 獲取一個包裝的迭代器ShuffleBlockFetcherIterator,它迭代的元素是blockId和這個block對應的讀取流,很顯然這個類就是實現reduce階段資料讀取的關鍵
  • 將原始讀取流轉換成反序列化後的迭代器
  • 將迭代器轉換成能夠統計度量值的迭代器,這一系列的轉換和java中對於流的各種裝飾器很類似
  • 將迭代器包裝成能夠相應中斷的迭代器。每讀一條資料就會檢查一下任務有沒有被殺死,這種做法是為了儘量及時地響應殺死任務的請求,比如從driver端發來殺死任務的訊息。
  • 利用聚合器對結果進行聚合。這裡再次利用了AppendonlyMap這個資料結構,前面shuffle寫階段也用到這個資料結構,它的內部是一個以陣列作為底層資料結構的,以線性探測法線性的hash表。
  • 最後對結果進行排序。

所以很顯然,我們想知道的shuffle讀取資料的具體邏輯就藏在ShuffleBlockFetcherIterator中

    private[spark] class BlockStoreShuffleReader[K, C](
        handle: BaseShuffleHandle[K, _, C],
        startPartition: Int,
        endPartition: Int,
        context: TaskContext,
        serializerManager: SerializerManager = SparkEnv.get.serializerManager,
        blockManager: BlockManager = SparkEnv.get.blockManager,
        mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
      extends ShuffleReader[K, C] with Logging {
    
      private val dep = handle.dependency
    
      /** Read the combined key-values for this reduce task */
      override def read(): Iterator[Product2[K, C]] = {
        // 獲取一個包裝的迭代器,它迭代的元素是blockId和這個block對應的讀取流
        val wrappedStreams = new ShuffleBlockFetcherIterator(
          context,
          blockManager.shuffleClient,
          blockManager,
          mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
          serializerManager.wrapStream,
          // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
          SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
          SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
          SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
          SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
          SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))
    
        val serializerInstance = dep.serializer.newInstance()
    
        // Create a key/value iterator for each stream
        // 將原始讀取流轉換成反序列化後的迭代器
        val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
          // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
          // NextIterator. The NextIterator makes sure that close() is called on the
          // underlying InputStream when all records have been read.
          serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
        }
    
        // Update the context task metrics for each record read.
        val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
        // 轉換成能夠統計度量值的迭代器,這一系列的轉換和java中對於流的各種裝飾器很類似
        val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
          recordIter.map { record =>
            readMetrics.incRecordsRead(1)
            record
          },
          context.taskMetrics().mergeShuffleReadMetrics())
    
        // An interruptible iterator must be used here in order to support task cancellation
        // 每讀一條資料就會檢查一下任務有沒有被殺死,
        // 這種做法是為了儘量及時地響應殺死任務的請求,比如從driver端發來殺死任務的訊息
        val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
    
        val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
          // 利用聚合器對結果進行聚合
          if (dep.mapSideCombine) {
            // We are reading values that are already combined
            val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
            dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
          } else {
            // We don't know the value type, but also don't care -- the dependency *should*
            // have made sure its compatible w/ this aggregator, which will convert the value
            // type to the combined type C
            val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
            dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
          }
        } else {
          require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
          interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
        }
    
        // Sort the output if there is a sort ordering defined.
        // 最後對結果進行排序
        dep.keyOrdering match {
          case Some(keyOrd: Ordering[K]) =>
            // Create an ExternalSorter to sort the data.
            val sorter =
              new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
            sorter.insertAll(aggregatedIter)
            context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
            context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
            context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
            CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
          case None =>
            aggregatedIter
        }
      }
    }

ShuffleBlockFetcherIterator

這個類比較複雜,仔細看在類初始化的程式碼中會呼叫initialize方法。
其次,我們應該注意它的構造器中的引數,

    val wrappedStreams = new ShuffleBlockFetcherIterator(
    context,
    // 如果沒有啟用外部shuffle服務,就是BlockTransferService
    blockManager.shuffleClient,
    blockManager,
    // 通過mapOutputTracker元件獲取每個分割槽對應的資料block的物理位置
    mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
    serializerManager.wrapStream,
    // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
    // 獲取幾個配置引數
    SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
    SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
    SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
    SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
    SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))

ShuffleBlockFetcherIterator.initialize

  • 首先將本地的block和遠端的block分隔開
  • 然後開始傳送請求拉取遠端資料。這個過程中會有一些約束條件限制拉取資料請求的數量,主要是正在獲取的總資料量的限制,請求併發數限制;每個遠端地址同時拉取的塊數也會有限制,但是這個閾值預設是Integer.MAX_VALUE
  • 獲取本地的block資料

其中,獲取本地資料較為簡單,主要就是通過本節點的BlockManager來獲取塊資料,並通過索引檔案獲取資料指定分割槽的資料。
我們著重分析遠端拉取的部分

private[this] def initialize(): Unit = {
// Add a task completion callback (called in both success case and failure case) to cleanup.
// 向TaskContext中新增一個回撥,在任務完成時做一些清理工作
context.addTaskCompletionListener(_ => cleanup())

// Split local and remote blocks.
// 將本地的block和遠端的block分隔開
val remoteRequests = splitLocalRemoteBlocks()
// Add the remote requests into our queue in a random order
fetchRequests ++= Utils.randomize(remoteRequests)
assert ((0 == reqsInFlight) == (0 == bytesInFlight),
  "expected reqsInFlight = 0 but found reqsInFlight = " + reqsInFlight +
  ", expected bytesInFlight = 0 but found bytesInFlight = " + bytesInFlight)

// Send out initial requests for blocks, up to our maxBytesInFlight
// 傳送遠端拉取資料的請求
// 儘可能多地傳送請求
// 但是會有一定的約束:
// 全域性性的約束,全域性拉取資料的rpc執行緒併發數,全域性拉取資料的資料量限制
// 每個遠端地址的限制:每個遠端地址同時拉取的塊數不能超過一定閾值
fetchUpToMaxBytes()

// 記錄已經發送的請求個數,仍然會有一部分沒有傳送請求
val numFetches = remoteRequests.size - fetchRequests.size
logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))

// Get Local Blocks
// 獲取本地的block資料
fetchLocalBlocks()
logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
}

ShuffleBlockFetcherIterator.splitLocalRemoteBlocks

我們首先來看如何切分遠端和本地的資料塊,總結一下這個方法:

  • 首先將同時拉取的資料量的大小除以5作為每次請求拉取的資料量的限制,這麼做的原因是為了允許同時從5個節點拉取資料,因為節點的網路環境可能並不穩定,同時從多個節點拉取資料有助於減少網路波動對效能帶來的影響,而對整體的同時拉取資料量的限制主要是為了限制本機網路流量的使用
  • 迴圈遍歷每一個節點地址(這裡是BlockManagerId),
  • 如果地址與本機地址相同,那麼對應的blocks就是本地block
  • 對於遠端block,則要根據同時拉取資料量大小的限制將每個節點的所有block切分成多個請求(FetchRequest),確保這些請求單次的拉取資料量不會太大

      private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
      // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them
      // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
      // nodes, rather than blocking on reading output from one node.
      // 之所以將請求大小減小到maxBytesInFlight / 5,
      // 是為了並行化地拉取資料,最毒允許同時從5個節點拉取資料
      val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
      logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize
        + ", maxBlocksInFlightPerAddress: " + maxBlocksInFlightPerAddress)
    
      // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
      // at most maxBytesInFlight in order to limit the amount of data in flight.
      val remoteRequests = new ArrayBuffer[FetchRequest]
    
      // Tracks total number of blocks (including zero sized blocks)
      // 記錄總的block數量
      var totalBlocks = 0
      for ((address, blockInfos) <- blocksByAddress) {
        totalBlocks += blockInfos.size
        // 如果地址與本地的BlockManager相同,就是本地block
        if (address.executorId == blockManager.blockManagerId.executorId) {
          // Filter out zero-sized blocks
          localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
          numBlocksToFetch += localBlocks.size
        } else {
          val iterator = blockInfos.iterator
          var curRequestSize = 0L
          var curBlocks = new ArrayBuffer[(BlockId, Long)]
          while (iterator.hasNext) {
            val (blockId, size) = iterator.next()
            // Skip empty blocks
            if (size > 0) {
              curBlocks += ((blockId, size))
              remoteBlocks += blockId
              numBlocksToFetch += 1
              curRequestSize += size
            } else if (size < 0) {
              throw new BlockException(blockId, "Negative block size " + size)
            }
            // 如果超過每次請求的資料量限制,那麼建立一次請求
            if (curRequestSize >= targetRequestSize ||
                curBlocks.size >= maxBlocksInFlightPerAddress) {
              // Add this FetchRequest
              remoteRequests += new FetchRequest(address, curBlocks)
              logDebug(s"Creating fetch request of $curRequestSize at $address "
                + s"with ${curBlocks.size} blocks")
              curBlocks = new ArrayBuffer[(BlockId, Long)]
              curRequestSize = 0
            }
          }
          // Add in the final request
          // 掃尾方法,最後剩餘的塊建立一次請求
          if (curBlocks.nonEmpty) {
            remoteRequests += new FetchRequest(address, curBlocks)
          }
        }
      }
      logInfo(s"Getting $numBlocksToFetch non-empty blocks out of $totalBlocks blocks")
      remoteRequests
      }

ShuffleBlockFetcherIterator.fetchUpToMaxBytes

回到initialize方法中,在完成本地與遠端block的切分後,我們得到了一批封裝好的資料拉取請求,將這些請求加到佇列中,接下來要做的是通過rpc客戶端傳送這些請求,

這個方法邏輯還是相對簡單,主要邏輯就是兩個迴圈,先發送延緩佇列中的請求,然後傳送正常的請求;之所以會有延緩佇列是因為這些請求在第一次待發送時因為資料量超過閾值或者請求數量超過閾值而不能傳送,所以就被放到延緩佇列中,而這裡的處理也是優先發送延緩佇列中的請求。每個請求在傳送前必須要滿足下面幾個條件才會被髮送:

  • 當前正在拉取的資料量不能超過閾值maxReqsInFlight(預設48m);這裡會有一個問題,如果某個block的資料量超過maxReqsInFlight值呢?這種情況下會等當前已經沒有進行中的資料拉取請求才會發送這個請求,因為在對當前請求資料量閾值進行判斷時會檢查bytesInFlight == 0,如果這個條件滿足就不會檢查本次請求的資料量是否會超過閾值。
  • 當前正在拉取的請求資料量不能超過閾值(預設Int.MaxValue)
  • 每個遠端地址的同時請求數量也會有限制(預設Int.MaxValue)
  • 最後符合條件的請求就會被髮送,這裡要提出的一點是如果一次請求的資料量超過maxReqSizeShuffleToMem值,那麼就會寫入磁碟的一個臨時檔案中,而這個閾值的預設值是Long.MaxValue,所以預設情況下是沒有限制的。

      // 傳送請求
      // 儘可能多地傳送請求
      // 但是會有一定的約束:
      // 全域性性的約束,全域性拉取資料的rpc執行緒併發數,全域性拉取資料的資料量限制
      // 每個遠端地址的限制:每個遠端地址同時拉取的塊數不能超過一定閾值
      private def fetchUpToMaxBytes(): Unit = {
      // Send fetch requests up to maxBytesInFlight. If you cannot fetch from a remote host
      // immediately, defer the request until the next time it can be processed.
    
      // Process any outstanding deferred fetch requests if possible.
      if (deferredFetchRequests.nonEmpty) {
        for ((remoteAddress, defReqQueue) <- deferredFetchRequests) {
          while (isRemoteBlockFetchable(defReqQueue) &&
              !isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front)) {
            val request = defReqQueue.dequeue()
            logDebug(s"Processing deferred fetch request for $remoteAddress with "
              + s"${request.blocks.length} blocks")
            send(remoteAddress, request)
            if (defReqQueue.isEmpty) {
              deferredFetchRequests -= remoteAddress
            }
          }
        }
      }
    
      // Process any regular fetch requests if possible.
      while (isRemoteBlockFetchable(fetchRequests)) {
        val request = fetchRequests.dequeue()
        val remoteAddress = request.address
        // 如果超過了同時拉取的塊數的限制,那麼將這個請求放到延緩佇列中,留待下次請求
        if (isRemoteAddressMaxedOut(remoteAddress, request)) {
          logDebug(s"Deferring fetch request for $remoteAddress with ${request.blocks.size} blocks")
          val defReqQueue = deferredFetchRequests.getOrElse(remoteAddress, new Queue[FetchRequest]())
          defReqQueue.enqueue(request)
          deferredFetchRequests(remoteAddress) = defReqQueue
        } else {
          send(remoteAddress, request)
        }
      }
    
      // 傳送一個請求,並且累加記錄請求的塊的數量,
      // 以用於在下次請求時檢查請求塊的數量是否超過閾值
      def send(remoteAddress: BlockManagerId, request: FetchRequest): Unit = {
        sendRequest(request)
        numBlocksInFlightPerAddress(remoteAddress) =
          numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size
      }
    
      // 這個限制是對所有的請求而言,不分具體是哪個遠端節點
      // 檢查當前的請求的數量是否還有餘量
      // 當前請求的大小是否還有餘量
      // 這主要是為了限制併發數和網路流量的使用
      def isRemoteBlockFetchable(fetchReqQueue: Queue[FetchRequest]): Boolean = {
        fetchReqQueue.nonEmpty &&
          (bytesInFlight == 0 ||
            (reqsInFlight + 1 <= maxReqsInFlight &&
              bytesInFlight + fetchReqQueue.front.size <= maxBytesInFlight))
      }
    
      // Checks if sending a new fetch request will exceed the max no. of blocks being fetched from a
      // given remote address.
      // 檢測正在拉取的塊的數量是否超過閾值
      // 每個地址都有一個同事拉取塊數的限制
      def isRemoteAddressMaxedOut(remoteAddress: BlockManagerId, request: FetchRequest): Boolean = {
        numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size >
          maxBlocksInFlightPerAddress
      }
      }

ShuffleBlockFetcherIterator.next

通過上一個方法的分析,我們能夠看出來,初始化時發起的拉取資料的請求並未將所有請求全部發送出去,並且還會有請求因為超過閾值而被放入延緩佇列中,那麼這些未傳送的請求是什麼時候被再次傳送的呢?答案就在next方法中。我們知道ShuffleBlockFetcherIterator是一個迭代器,所以外部呼叫者對元素的訪問是通過next方法,所以很容易想到next方法中肯定會有傳送拉取資料請求的邏輯。
總結一下:

  • 首先從結果佇列中獲取一個拉取成功的結果(結果佇列是一個阻塞佇列,如果沒有拉取成功的結果會阻塞呼叫者)
  • 拿到一個結果後檢查這個結果是拉取成功還是拉取失敗,如果失敗則直接拋異常(重試的邏輯實在rpc客戶端實現的,不是在這裡實現)
  • 如果是一個成功的結果,首先要更新一下一些任務度量值,更新一些內部的簿記量,如正在拉取的資料量
  • 將拉取到的位元組緩衝包裝成一個位元組輸入流
  • 通過外部傳進來的函式對流再包裝一次,通過外部傳進來的函式再包裝一次,一般是解壓縮和解密
  • 而且流被壓縮或者加密過,如果塊的大小比較小,那麼要將這個流拷貝一份,這樣就會實際出發解壓縮和解密,以此來儘早暴露塊損壞的 問題
  • 最後一句關鍵語句,再次發起一輪拉取資料請求的發 送,因為經過next處理之後,已經有拉取成功的資料了,正在拉取的資料量和請求數量可能減小了,這就為傳送新的請求騰出空間

      override def next(): (BlockId, InputStream) = {
      if (!hasNext) {
        throw new NoSuchElementException
      }
    
      numBlocksProcessed += 1
    
      var result: FetchResult = null
      var input: InputStream = null
      // Take the next fetched result and try to decompress it to detect data corruption,
      // then fetch it one more time if it's corrupt, throw FailureFetchResult if the second fetch
      // is also corrupt, so the previous stage could be retried.
      // For local shuffle block, throw FailureFetchResult for the first IOException.
      while (result == null) {
        val startFetchWait = System.currentTimeMillis()
        result = results.take()
        val stopFetchWait = System.currentTimeMillis()
        shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)
    
        result match {
          case r @ SuccessFetchResult(blockId, address, size, buf, isNetworkReqDone) =>
            if (address != blockManager.blockManagerId) {
              numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
              // 主要是更新一些度量值
              shuffleMetrics.incRemoteBytesRead(buf.size)
              if (buf.isInstanceOf[FileSegmentManagedBuffer]) {
                shuffleMetrics.incRemoteBytesReadToDisk(buf.size)
              }
              shuffleMetrics.incRemoteBlocksFetched(1)
            }
            bytesInFlight -= size
            if (isNetworkReqDone) {
              reqsInFlight -= 1
              logDebug("Number of requests in flight " + reqsInFlight)
            }
    
            // 將位元組緩衝包裝成一個位元組輸入流
            val in = try {
              buf.createInputStream()
            } catch {
              // The exception could only be throwed by local shuffle block
              case e: IOException =>
                assert(buf.isInstanceOf[FileSegmentManagedBuffer])
                logError("Failed to create input stream from local block", e)
                buf.release()
                throwFetchFailedException(blockId, address, e)
            }
    
            // 通過外部傳進來的函式再包裝一次,一般是增加壓縮和加密的功能
            input = streamWrapper(blockId, in)
            // Only copy the stream if it's wrapped by compression or encryption, also the size of
            // block is small (the decompressed block is smaller than maxBytesInFlight)
            // 如果塊的大小比較小,而且流被壓縮或者加密過,那麼需要將這個流拷貝一份
            if (detectCorrupt && !input.eq(in) && size < maxBytesInFlight / 3) {
              val originalInput = input
              val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate)
              try {
                // Decompress the whole block at once to detect any corruption, which could increase
                // the memory usage tne potential increase the chance of OOM.
                // TODO: manage the memory used here, and spill it into disk in case of OOM.
                Utils.copyStream(input, out)
                out.close()
                input = out.toChunkedByteBuffer.toInputStream(dispose = true)
              } catch {
                case e: IOException =>
                  buf.release()
                  if (buf.isInstanceOf[FileSegmentManagedBuffer]
                    || corruptedBlocks.contains(blockId)) {
                    throwFetchFailedException(blockId, address, e)
                  } else {
                    logWarning(s"got an corrupted block $blockId from $address, fetch again", e)
                    corruptedBlocks += blockId
                    fetchRequests += FetchRequest(address, Array((blockId, size)))
                    result = null
                  }
              } finally {
                // TODO: release the buf here to free memory earlier
                originalInput.close()
                in.close()
              }
            }
    
            // 拉取失敗,拋異常
            // 這裡思考一下:拉取塊資料肯定是有重試機制的,但是這裡拉取失敗之後直接拋異常是為何??
            // 答案是:重試機制並不是正在這裡實現 的,而是在rpc客戶端傳送拉取請求時實現了重試機制
            // 也就是說如果到這裡是失敗的話,說明已經經過重試後還是失敗的,所以這裡直接拋異常就行了
          case FailureFetchResult(blockId, address, e) =>
            throwFetchFailedException(blockId, address, e)
        }
    
        // Send fetch requests up to maxBytesInFlight
        // 這裡再次傳送拉取請求,因為前面已經有成功拉取到的資料,
        // 所以正在拉取中的資料量就會減小,所以就能為新的請求騰出空間
        fetchUpToMaxBytes()
      }
    
      currentResult = result.asInstanceOf[SuccessFetchResult]
      (currentResult.blockId, new BufferReleasingInputStream(input, this))
      }

總結

到此,我們就把shuffle讀的過程大概分析完了。整體下來,感覺主幹邏輯不是很複雜,但是裡面有很多細碎邏輯,所以上面的分析還是比較碎,這裡把整個過程的主幹邏輯再提煉一下,以便能有個整體的認識:

  • 首先,在一些shuffle型別的RDD中,它的計算方法compute會通過ShuffleManager獲取一個block資料讀取器BlockStoreShuffleReader
  • 通過BlockStoreShuffleReader中的read方法進行資料的讀取,一個reduce端分割槽的資料一般會依賴於所有的map端輸出的分割槽資料,所以資料一般會在多個executor(注意是executor節點,通過BlockManagerId唯一標識,一個物理節點可能會執行多個executor節點)節點上,而且每個executor節點也可能會有多個block,在shuffle寫過程的分析中我們也提到,每個map最後時輸出一個數據檔案和索引檔案,也就是一個block,但是因為一個節點
  • 這個方法通過ShuffleBlockFetcherIterator物件封裝了遠端拉取資料的複雜邏輯,並且最終將拉取到的資料封裝成流的迭代器的形式
  • 對所有的block的流進行層層裝飾,包括反序列化,任務度量值(讀入資料條數)統計,每條資料可中斷,
  • 對資料進行聚合
  • 對聚合後的資料進行排序

所以,從這裡我們也能看出來,新版的shuffle機制中,也就是SortShuffleManager,使用者程式碼對於shuffle之後的rdd拿到的是經過排序的資料(如果指定排序器的話)