Spark Shuffle Module: Analysis of the Shuffle Read Process
The Spark Shuffle Read call stack is as follows:
1. org.apache.spark.rdd.ShuffledRDD#compute()
2. org.apache.spark.shuffle.ShuffleManager#getReader()
3. org.apache.spark.shuffle.hash.HashShuffleReader#read()
4. org.apache.spark.storage.ShuffleBlockFetcherIterator#initialize()
5. org.apache.spark.storage.ShuffleBlockFetcherIterator#splitLocalRemoteBlocks()
org.apache.spark.storage.ShuffleBlockFetcherIterator#sendRequest()
org.apache.spark.storage.ShuffleBlockFetcherIterator#fetchLocalBlocks()
The classes and methods involved when fetchLocalBlocks() executes are:
6. org.apache.spark.storage.BlockManager#getBlockData()
org.apache.spark.shuffle.ShuffleManager#shuffleBlockResolver()
ShuffleManager has two main implementations. For Hash Shuffle, the corresponding method is org.apache.spark.shuffle.hash.HashShuffleManager#shuffleBlockResolver(), which returns an org.apache.spark.shuffle.FileShuffleBlockResolver; FileShuffleBlockResolver#getBlockData() is then called to return the block data. For Sort Shuffle, the corresponding method is org.apache.spark.shuffle.sort.SortShuffleManager#shuffleBlockResolver(), which returns an org.apache.spark.shuffle.IndexShuffleBlockResolver; IndexShuffleBlockResolver#getBlockData() is then called to return the block data.
The classes and methods involved when org.apache.spark.storage.ShuffleBlockFetcherIterator#sendRequest() executes are:
7. org.apache.spark.network.shuffle.ShuffleClient#fetchBlocks()
org.apache.spark.network.shuffle.ShuffleClient has two subclasses, ExternalShuffleClient and BlockTransferService. org.apache.spark.network.BlockTransferService in turn has two subclasses, NettyBlockTransferService and NioBlockTransferService, corresponding to two different ways of fetching block data from remote executors. In Spark 1.5.2, NioBlockTransferService is already deprecated and is slated for removal in a later release.
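The methods are explained below in the order of the call stack above. Only the overall flow is covered here; the details will be discussed later.
ShuffledRDD#compute()
When a Task runs, it calls ShuffledRDD's compute method, whose code is as follows: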
//org.apache.spark.rdd.ShuffledRDD#compute()
override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
  val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
  // Obtain a reader via org.apache.spark.shuffle.ShuffleManager#getReader().
  // Whether Sort Shuffle or Hash Shuffle is used, the reader is
  // org.apache.spark.shuffle.hash.HashShuffleReader
  SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
    .read()
    .asInstanceOf[Iterator[(K, C)]]
}
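As the code shows, the core logic obtains a HashShuffleReader by calling ShuffleManager#getReader() and then calls HashShuffleReader#read() to read the shuffle data written by the ShuffleMapTasks of the previous Stage. Only one reduce partition is requested: the range [split.index, split.index + 1). Note that both Hash Shuffle and Sort Shuffle use HashShuffleReader.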
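To make the partition range concrete, here is a minimal in-memory sketch, assuming a toy layout in which every map task's output is already split into one record list per reduce partition. The names readPartitionRange and computeSketch are hypothetical and none of this is Spark code; it only shows that a reduce task reading [split.index, split.index + 1) collects its own partition from every map output.

```scala
// Hypothetical toy model, not Spark's API:
// mapOutputs(mapId)(reduceId) holds the records that map task wrote for that reducer.
def readPartitionRange[T](
    mapOutputs: Seq[Seq[Seq[T]]],
    startPartition: Int,
    endPartition: Int): Iterator[T] =
  for {
    mapOutput <- mapOutputs.iterator            // every map task's output
    reduceId <- startPartition until endPartition // reduce partitions in [start, end)
    record <- mapOutput(reduceId).iterator
  } yield record

// Like ShuffledRDD#compute(): request exactly one partition, [index, index + 1).
def computeSketch[T](mapOutputs: Seq[Seq[Seq[T]]], splitIndex: Int): Iterator[T] =
  readPartitionRange(mapOutputs, splitIndex, splitIndex + 1)
```

For two map tasks and two reducers, computeSketch(outputs, 1) returns the partition-1 records of map 0 followed by those of map 1.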
HashShuffleReader#read()
Stepping into HashShuffleReader#read(), its source is as follows:
/** Read the combined key-values for this reduce task */
override def read(): Iterator[Product2[K, C]] = {
  // Create a ShuffleBlockFetcherIterator; its constructor calls initialize(),
  // which runs splitLocalRemoteBlocks() to decide how each block is read:
  // remote blocks are fetched via sendRequest(),
  // local blocks are read via fetchLocalBlocks()
  val blockFetcherItr = new ShuffleBlockFetcherIterator(
    context,
    blockManager.shuffleClient,
    blockManager,
    mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition),
    // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
    SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)
  // Wrap the streams for compression based on configuration
  val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
    blockManager.wrapForCompression(blockId, inputStream)
  }
  val ser = Serializer.getSerializer(dep.serializer)
  val serializerInstance = ser.newInstance()
  // Create a key/value iterator for each stream
  val recordIter = wrappedStreams.flatMap { wrappedStream =>
    // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
    // NextIterator. The NextIterator makes sure that close() is called on the
    // underlying InputStream when all records have been read.
    serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
  }
  // Update the context task metrics for each record read.
  val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
  val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
    recordIter.map(record => {
      readMetrics.incRecordsRead(1)
      record
    }),
    context.taskMetrics().updateShuffleReadMetrics())
  // An interruptible iterator must be used here in order to support task cancellation
  val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
  val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
    if (dep.mapSideCombine) {
      // Map-side combine already ran: merge the partially combined values (combiners)
      val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
      dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
    } else {
      // No map-side combine: aggregate the raw values on the reduce side
      val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
      dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
    }
  } else {
    require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
    interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
  }
  // Sort the output if the dependency defines a key ordering
  dep.keyOrdering match {
    case Some(keyOrd: Ordering[K]) =>
      // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
      // the ExternalSorter won't spill to disk.
      val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser))
      sorter.insertAll(aggregatedIter)
      context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
      context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
      context.internalMetricsToAccumulators(
        InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes)
      sorter.iterator
    case None =>
      aggregatedIter
  }
}
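The tail of read() chooses between two aggregation paths and an optional sort. The sketch below models that decision in memory, assuming no spilling and no metrics. AggregatorSketch and sortIfOrdered are hypothetical names; the three functions mirror the roles of createCombiner, mergeValue and mergeCombiners in Spark's Aggregator, but this is a simplified stand-in, not Spark's implementation.

```scala
import scala.collection.mutable

// Hypothetical, in-memory stand-in for the aggregation step (no spilling).
case class AggregatorSketch[K, V, C](
    createCombiner: V => C,
    mergeValue: (C, V) => C,
    mergeCombiners: (C, C) => C) {

  // No map-side combine: raw (K, V) pairs arrive and are folded into combiners.
  def combineValuesByKey(it: Iterator[(K, V)]): Iterator[(K, C)] = {
    val acc = mutable.LinkedHashMap.empty[K, C]
    it.foreach { case (k, v) =>
      acc(k) = acc.get(k) match {
        case Some(c) => mergeValue(c, v)
        case None    => createCombiner(v)
      }
    }
    acc.iterator
  }

  // Map-side combine already ran: partial combiners (K, C) are merged together.
  def combineCombinersByKey(it: Iterator[(K, C)]): Iterator[(K, C)] = {
    val acc = mutable.LinkedHashMap.empty[K, C]
    it.foreach { case (k, c) =>
      acc(k) = acc.get(k).map(mergeCombiners(_, c)).getOrElse(c)
    }
    acc.iterator
  }
}

// Plays the role of the ExternalSorter branch: sort only when an ordering exists.
def sortIfOrdered[K, C](it: Iterator[(K, C)], keyOrdering: Option[Ordering[K]]): Iterator[(K, C)] =
  keyOrdering match {
    case Some(ord) => it.toVector.sortBy(_._1)(ord).iterator
    case None      => it
  }
```

With a summing aggregator, raw pairs ("b", 1), ("a", 2), ("b", 3) become ("a", 2), ("b", 4) after combineValuesByKey and a String ordering.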
ShuffleBlockFetcherIterator#splitLocalRemoteBlocks()
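splitLocalRemoteBlocks() decides how the data will be read. The localBlocks variable records the BlockIds held by the local executor (blocks whose executorId matches this BlockManager's), and remoteBlocks records the BlockIds held by all other executors. Empty blocks are skipped. Remote blocks are grouped per address into FetchRequests; a request is closed once its accumulated size reaches targetRequestSize = max(maxBytesInFlight / 5, 1), so that several requests to up to five nodes can be in flight in parallel instead of one large fetch from a single node.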
The source of splitLocalRemoteBlocks() is as follows:
private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
  // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them
  // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
  // nodes, rather than blocking on reading output from one node.
  // maxBytesInFlight caps the total data in flight; it defaults to 48 MB and is set via
  // SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024
  val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
  logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize)
  // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
  // at most maxBytesInFlight in order to limit the amount of data in flight.
  val remoteRequests = new ArrayBuffer[FetchRequest]
  // Tracks total number of blocks (including zero sized blocks)
  var totalBlocks = 0
  for ((address, blockInfos) <- blocksByAddress) {
    totalBlocks += blockInfos.size
    // The blocks live on the local executor
    if (address.executorId == blockManager.blockManagerId.executorId) {
      // Filter out zero-sized blocks
      // Record the BlockIds of the local blocks
      localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
      numBlocksToFetch += localBlocks.size
    } else {
      // The blocks live on a remote executor
      val iterator = blockInfos.iterator
      var curRequestSize = 0L
      var curBlocks = new ArrayBuffer[(BlockId, Long)]
      while (iterator.hasNext) {
        val (blockId, size) = iterator.next()
        // Skip empty blocks
        if (size > 0) {
          curBlocks += ((blockId, size))
          // Record the BlockId of the remote block
          remoteBlocks += blockId
          numBlocksToFetch += 1
          curRequestSize += size
        } else if (size < 0) {
          throw new BlockException(blockId, "Negative block size " + size)
        }
        if (curRequestSize >= targetRequestSize) {
          // Add this FetchRequest
          remoteRequests += new FetchRequest(address, curBlocks)
          curBlocks = new ArrayBuffer[(BlockId, Long)]
          logDebug(s"Creating fetch request of $curRequestSize at $address")
          curRequestSize = 0
        }
      }
      // Add in the final request
      if (curBlocks.nonEmpty) {
        remoteRequests += new FetchRequest(address, curBlocks)
      }
    }
  }
  logInfo(s"Getting $numBlocksToFetch non-empty blocks out of $totalBlocks blocks")
  remoteRequests
}
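The grouping rule is easier to see on small numbers. The following sketch reproduces the same local/remote split and request-cutting logic under simplifying assumptions: an address is reduced to an executor id string, a block to a (blockId, size) pair, and FetchRequestSketch and splitLocalRemote are hypothetical names rather than Spark classes.

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical simplified request: executor id plus (blockId, size) pairs.
case class FetchRequestSketch(executorId: String, blocks: Seq[(String, Long)]) {
  def size: Long = blocks.map(_._2).sum
}

def splitLocalRemote(
    blocksByAddress: Seq[(String, Seq[(String, Long)])],
    localExecutorId: String,
    maxBytesInFlight: Long): (Seq[String], Seq[FetchRequestSketch]) = {
  val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
  val localBlocks = ArrayBuffer[String]()
  val remoteRequests = ArrayBuffer[FetchRequestSketch]()
  for ((executorId, blockInfos) <- blocksByAddress) {
    if (executorId == localExecutorId) {
      // Local executor: keep only non-empty block ids
      localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
    } else {
      var curBlocks = ArrayBuffer[(String, Long)]()
      var curRequestSize = 0L
      for ((blockId, size) <- blockInfos) {
        require(size >= 0, s"Negative block size $size")
        if (size > 0) {          // empty blocks are skipped
          curBlocks += ((blockId, size))
          curRequestSize += size
        }
        // Close the current request once it reaches the target size
        if (curRequestSize >= targetRequestSize) {
          remoteRequests += FetchRequestSketch(executorId, curBlocks.toList)
          curBlocks = ArrayBuffer[(String, Long)]()
          curRequestSize = 0L
        }
      }
      // Leftover blocks form the final, possibly smaller, request
      if (curBlocks.nonEmpty) remoteRequests += FetchRequestSketch(executorId, curBlocks.toList)
    }
  }
  (localBlocks.toList, remoteRequests.toList)
}
```

With maxBytesInFlight = 50 the target is 10 bytes, so remote blocks of sizes 6, 5, 3 become one request of 11 bytes (6 + 5 crosses the target) and a final request of 3 bytes. A request can therefore overshoot the target by up to one block.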
ShuffleBlockFetcherIterator#fetchLocalBlocks()
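fetchLocalBlocks() reads the local blocks by calling BlockManager#getBlockData(). If reading any block throws, it enqueues a FailureFetchResult and stops immediately. Its source is as follows: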
private[this] def fetchLocalBlocks() {
  val iter = localBlocks.iterator
  while (iter.hasNext) {
    val blockId = iter.next()
    try {
      // Call BlockManager#getBlockData
      val buf = blockManager.getBlockData(blockId)
      shuffleMetrics.incLocalBlocksFetched(1)
      shuffleMetrics.incLocalBytesRead(buf.size)
      buf.retain()
      results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf))
    } catch {
      case e: Exception =>
        // If we see an exception, stop immediately.
        logError(s"Error occurred while fetching local blocks", e)
        results.put(new FailureFetchResult(blockId, blockManager.blockManagerId, e))
        return
    }
  }
}
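The stop-on-first-failure behavior can be modeled with a plain java.util.concurrent.LinkedBlockingQueue playing the role of the results queue. This is a simplified sketch: FetchResultSketch, SuccessSketch, FailureSketch and fetchLocalBlocksSketch are hypothetical, a read function stands in for BlockManager#getBlockData, and a flag replaces the original early return.

```scala
import java.util.concurrent.LinkedBlockingQueue

// Hypothetical stand-ins for SuccessFetchResult / FailureFetchResult.
sealed trait FetchResultSketch
case class SuccessSketch(blockId: String, bytes: Array[Byte]) extends FetchResultSketch
case class FailureSketch(blockId: String, error: Throwable) extends FetchResultSketch

// Same control flow as fetchLocalBlocks(): enqueue each result,
// and after the first exception enqueue a failure and stop reading.
def fetchLocalBlocksSketch(
    localBlocks: Seq[String],
    read: String => Array[Byte],
    results: LinkedBlockingQueue[FetchResultSketch]): Unit = {
  val iter = localBlocks.iterator
  var failed = false
  while (iter.hasNext && !failed) {
    val blockId = iter.next()
    try {
      results.put(SuccessSketch(blockId, read(blockId)))
    } catch {
      case e: Exception =>
        results.put(FailureSketch(blockId, e))
        failed = true   // later blocks are never read
    }
  }
}
```

If the second of three blocks fails, the queue ends up with one success and one failure, and the third block is never touched; the consumer of the queue then sees the failure and can abort the task.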
Stepping into BlockManager#getBlockData(), its source is as follows:
override def getBlockData(blockId: BlockId): ManagedBuffer = {
  if (blockId.isShuffle) {
    // First call ShuffleManager#shuffleBlockResolver to get the ShuffleBlockResolver,
    // then call its getBlockData method
    shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId])
  } else {
    val blockBytesOpt = doGetLocal(blockId, asBlockResult = false)
      .asInstanceOf[Option[ByteBuffer]]
    if (blockBytesOpt.isDefined) {
      val buffer = blockBytesOpt.get
      new NioManagedBuffer(buffer)
    } else {
      throw new BlockNotFoundException(blockId.toString)
    }
  }
}
ShuffleManager#shuffleBlockResolver() returns the matching ShuffleBlockResolver: org.apache.spark.shuffle.FileShuffleBlockResolver for Hash Shuffle and org.apache.spark.shuffle.IndexShuffleBlockResolver for Sort Shuffle. The resolver's getBlockData method is then called and returns a FileSegmentManagedBuffer wrapping the file segment that holds the requested block.
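The dispatch in getBlockData() can be summarized in a few lines. In this sketch the block id types, resolverFor and getBlockDataSketch are hypothetical; the strings "hash" and "sort" stand in for the configured shuffle manager, other managers are not modeled, and a Map stands in for the local block store.

```scala
// Hypothetical simplified block ids.
sealed trait BlockIdSketch
case class ShuffleBlockIdSketch(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockIdSketch
case class RddBlockIdSketch(rddId: Int, splitIndex: Int) extends BlockIdSketch

// Which resolver serves shuffle blocks depends on the shuffle manager in use.
def resolverFor(shuffleManager: String): String = shuffleManager match {
  case "hash" => "FileShuffleBlockResolver"
  case "sort" => "IndexShuffleBlockResolver"
  case other  => throw new IllegalArgumentException(s"Not modeled here: $other")
}

// Shuffle blocks go to the resolver; any other block is looked up locally.
def getBlockDataSketch(
    id: BlockIdSketch,
    shuffleManager: String,
    localStore: Map[BlockIdSketch, String]): String = id match {
  case s: ShuffleBlockIdSketch => s"${resolverFor(shuffleManager)}.getBlockData($s)"
  case other =>
    localStore.getOrElse(other, throw new NoSuchElementException(s"Block $other not found"))
}
```

The pattern match replaces the isShuffle check plus cast of the original, which is the idiomatic Scala way to branch on a sealed type.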
The source of FileShuffleBlockResolver#getBlockData is as follows:
override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {
  // Files produced by the Shuffle Consolidate Files mechanism of Hash Shuffle
  if (consolidateShuffleFiles) {
    // Search all file groups associated with this shuffle.
    val shuffleState = shuffleStates(blockId.shuffleId)
    val iter = shuffleState.allFileGroups.iterator
    while (iter.hasNext) {
      val segmentOpt = iter.next.getFileSegmentFor(blockId.mapId, blockId.reduceId)
      if (segmentOpt.isDefined) {
        val segment = segmentOpt.get
        return new FileSegmentManagedBuffer(
          transportConf, segment.file, segment.offset, segment.length)
      }
    }
    throw new IllegalStateException("Failed to find shuffle block: " + blockId)
  } else {
    // Files produced by plain Hash Shuffle: one whole file per block
    val file = blockManager.diskBlockManager.getFile(blockId)
    new FileSegmentManagedBuffer(transportConf, file, 0, file.length)
  }
}
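The two branches differ only in how a (mapId, reduceId) pair becomes a file segment. Here is a simplified model, assuming a consolidated file group can be represented as a map from (mapId, reduceId) to (offset, length). FileSegmentSketch, FileGroupSketch, lookupSegment and the fileFor callback (standing in for diskBlockManager.getFile) are hypothetical; the per-block file name follows the shuffle_{shuffleId}_{mapId}_{reduceId} pattern of ShuffleBlockId.

```scala
import java.io.File

// Hypothetical segment and file-group types.
case class FileSegmentSketch(file: File, offset: Long, length: Long)

case class FileGroupSketch(file: File, segments: Map[(Int, Int), (Long, Long)]) {
  def getFileSegmentFor(mapId: Int, reduceId: Int): Option[FileSegmentSketch] =
    segments.get((mapId, reduceId)).map { case (off, len) => FileSegmentSketch(file, off, len) }
}

def lookupSegment(
    consolidate: Boolean,
    groups: Seq[FileGroupSketch],
    shuffleId: Int, mapId: Int, reduceId: Int,
    fileFor: String => File): FileSegmentSketch =
  if (consolidate) {
    // Consolidated files: search every file group; the first match wins
    groups.iterator
      .map(_.getFileSegmentFor(mapId, reduceId))
      .collectFirst { case Some(segment) => segment }
      .getOrElse(throw new IllegalStateException(
        s"Failed to find shuffle block: shuffle_${shuffleId}_${mapId}_${reduceId}"))
  } else {
    // Plain Hash Shuffle: one file per block, so the segment is the whole file
    val file = fileFor(s"shuffle_${shuffleId}_${mapId}_${reduceId}")
    FileSegmentSketch(file, 0L, file.length)
  }
```

The consolidated branch is a linear scan over file groups, matching the while loop in the original.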
The source of IndexShuffleBlockResolver#getBlockData is as follows:
override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {
  // The block is actually going to be a range of a single map output file for this map, so
  // find out the consolidated file, then the offset within that from our index
  // Locate the index file from shuffleId and mapId
  val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId)
  val in = new DataInputStream(new FileInputStream(indexFile))
  try {
    // Skip to this block's entry: each offset is an 8-byte long
    ByteStreams.skipFully(in, blockId.reduceId * 8)
    // Start offset of the block in the data file
    val offset = in.readLong()
    // End offset (the start offset of the next reduce partition)
    val nextOffset = in.readLong()
    // Return the file segment
    new FileSegmentManagedBuffer(
      transportConf,
      getDataFile(blockId.shuffleId, blockId.mapId),
      offset,
      nextOffset - offset)
  } finally {
    in.close()
  }
}
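The index lookup is plain arithmetic on 8-byte longs. The sketch below writes a small index file and reads segments back, assuming the index holds cumulative offsets starting at 0 (one entry per partition boundary), a layout consistent with the reader above. writeIndexSketch and segmentFor are hypothetical helpers, and RandomAccessFile.seek is swapped in for DataInputStream plus Guava's ByteStreams.skipFully to keep the example dependency-free.

```scala
import java.io.{DataOutputStream, File, FileOutputStream, RandomAccessFile}

// Hypothetical writer: numPartitions + 1 cumulative offsets, starting at 0.
def writeIndexSketch(file: File, partitionLengths: Seq[Long]): Unit = {
  val out = new DataOutputStream(new FileOutputStream(file))
  try {
    var offset = 0L
    out.writeLong(offset)
    partitionLengths.foreach { len =>
      offset += len
      out.writeLong(offset)
    }
  } finally out.close()
}

// Same arithmetic as getBlockData(): seek to reduceId * 8, read two longs,
// and return (offset, length) of the block inside the data file.
def segmentFor(indexFile: File, reduceId: Int): (Long, Long) = {
  val raf = new RandomAccessFile(indexFile, "r")
  try {
    raf.seek(reduceId * 8L)
    val offset = raf.readLong()
    val nextOffset = raf.readLong()
    (offset, nextOffset - offset)
  } finally raf.close()
}
```

For partition lengths 10, 0 and 15 the index holds 0, 10, 10, 25 (32 bytes), so reduce partition 2 maps to offset 10 and length 15, and the empty partition 1 to length 0.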
ShuffleBlockFetcherIterator#sendRequest()
sendRequest() fetches data from remote executors:
private[this] def sendRequest(req: FetchRequest) {
  logDebug("Sending request for %d blocks (%s) from %s".format(
    req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
  bytesInFlight += req.size
  // so we can look up the size of each blockID
  val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
  val blockIds = req.blocks.map(_._1.toString)
  val address = req.address
  // Fetch the data with ShuffleClient#fetchBlocks.
  // There are two kinds of ShuffleClient: ExternalShuffleClient and BlockTransferService;
  // BlockTransferService is used unless the external shuffle service is enabled
  shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
    new BlockFetchingListener {
      override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
        // Only add the buffer to results queue if the iterator is not zombie,
        // i.e. cleanup() has not been called yet.
        if (!isZombie) {
          // Increment the ref count because we need to pass this to a different thread.
          // This needs to be released after use.
          buf.retain()
          results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf))
          shuffleMetrics.incRemoteBytesRead(buf.size)
          shuffleMetrics.incRemoteBlocksFetched(1)
        }
        logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
      }
      override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
        logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
        results.put(new FailureFetchResult(BlockId(blockId), address, e))
      }
    }
  )
}
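The asynchronous part of sendRequest() is the listener: the transport layer calls it back on another thread, and the callbacks push results into a queue that the iterator drains. Here is a minimal single-threaded model of that contract; FetchListenerSketch, RemoteResult and SendRequestSketch are hypothetical, and the client function is a fake that invokes the callbacks synchronously rather than over the network.

```scala
import java.util.concurrent.LinkedBlockingQueue

// Hypothetical listener with the same two callbacks as BlockFetchingListener.
trait FetchListenerSketch {
  def onBlockFetchSuccess(blockId: String, data: Array[Byte]): Unit
  def onBlockFetchFailure(blockId: String, e: Throwable): Unit
}

sealed trait RemoteResult
case class RemoteSuccess(blockId: String, size: Long) extends RemoteResult
case class RemoteFailure(blockId: String, e: Throwable) extends RemoteResult

// The client is any function that fetches ids and reports back via the listener.
class SendRequestSketch(client: (Seq[String], FetchListenerSketch) => Unit) {
  val results = new LinkedBlockingQueue[RemoteResult]()
  @volatile var isZombie = false // true once the iterator has been cleaned up
  var bytesInFlight = 0L

  def sendRequest(blocks: Seq[(String, Long)]): Unit = {
    bytesInFlight += blocks.map(_._2).sum
    val sizeMap = blocks.toMap // look up each block's expected size by id
    client(blocks.map(_._1), new FetchListenerSketch {
      def onBlockFetchSuccess(blockId: String, data: Array[Byte]): Unit =
        // Late results after cleanup are dropped, as in the original
        if (!isZombie) results.put(RemoteSuccess(blockId, sizeMap(blockId)))
      def onBlockFetchFailure(blockId: String, e: Throwable): Unit =
        results.put(RemoteFailure(blockId, e))
    })
  }
}
```

Because the queue is thread-safe, the same structure works when the callbacks arrive on network threads; the fake synchronous client just makes the behavior easy to observe.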
As the code shows, remote block data is fetched with shuffleClient.fetchBlocks. org.apache.spark.network.shuffle.ShuffleClient has two subclasses, ExternalShuffleClient and BlockTransferService, and org.apache.spark.network.BlockTransferService in turn has two subclasses, NettyBlockTransferService and NioBlockTransferService. The shuffleClient object is defined in org.apache.spark.storage.BlockManager; its source is as follows:
// shuffleClient as defined in org.apache.spark.storage.BlockManager
private[spark] val shuffleClient = if (externalShuffleServiceEnabled) {
  // Use ExternalShuffleClient to fetch remote block data
  val transConf = SparkTransportConf.fromSparkConf(conf, numUsableCores)
  new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled(),
    securityManager.isSaslEncryptionEnabled())
} else {
  // Use NettyBlockTransferService or NioBlockTransferService to fetch remote block data
  blockTransferService
}
The blockTransferService used above is initialized in SparkEnv as follows:
// blockTransferService initialized in org.apache.spark.SparkEnv
val blockTransferService =
  conf.get("spark.shuffle.blockTransferService", "netty").toLowerCase match {
    case "netty" =>
      new NettyBlockTransferService(conf, securityManager, numUsableCores)
    case "nio" =>
      logWarning("NIO-based block transfer service is deprecated, " +
        "and will be removed in Spark 1.6.0.")
      new NioBlockTransferService(conf, securityManager)
  }
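Taken together, the two snippets above make a two-level, configuration-driven choice. The sketch below models it with a Map standing in for SparkConf and strings naming the classes described above. It assumes externalShuffleServiceEnabled comes from spark.shuffle.service.enabled (default false); chooseShuffleClient and chooseBlockTransferService are hypothetical helpers. Unlike the original match, which has no default case and would fail with a MatchError, the sketch rejects unknown values explicitly.

```scala
// Hypothetical model of the choices; Map[String, String] stands in for SparkConf.
def chooseBlockTransferService(conf: Map[String, String]): String =
  conf.getOrElse("spark.shuffle.blockTransferService", "netty").toLowerCase match {
    case "netty" => "NettyBlockTransferService"
    case "nio"   => "NioBlockTransferService (deprecated)"
    case other   => throw new IllegalArgumentException(s"Unknown block transfer service: $other")
  }

// External shuffle service first; otherwise fall back to the block transfer service.
def chooseShuffleClient(conf: Map[String, String]): String =
  if (conf.getOrElse("spark.shuffle.service.enabled", "false").toBoolean) "ExternalShuffleClient"
  else chooseBlockTransferService(conf)
```

With an empty configuration the result is NettyBlockTransferService, which is why Netty serves remote shuffle fetches by default.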