
Analysis of Spark task execution on the executor side

CoarseGrainedExecutorBackend

In the previous post we analyzed the job submission process, strictly speaking the driver-side part of it: after a job is submitted, DAGScheduler splits it into multiple stages according to shuffle dependencies and submits them one by one; for each stage it creates as many Tasks as there are partitions and wraps them into a task set, which is handed to TaskSchedulerImpl for assignment. TaskSchedulerImpl assigns tasks in a round-robin fashion over the compute resources (executors) offered by the SchedulerBackend, taking into account task locality, the blacklist, and the scheduling order of the pools, and returns the task-to-executor assignments to the SchedulerBackend as TaskDescriptions. The SchedulerBackend then serializes each TaskDescription again and sends it to the corresponding executor for execution. In this post we analyze how a Task runs on the executor.

Task execution entry point: Executor.launchTask

First, we know that CoarseGrainedExecutorBackend is the executor implementation class used in YARN mode, and it is an RPC server endpoint. So, starting from the messages sent by the RPC client, i.e. CoarseGrainedSchedulerBackend, we can locate the method on the server side that handles each message and follow the trail to the task execution entry point. From the previous post we know that when launching tasks, CoarseGrainedSchedulerBackend sends a LaunchTask message. Let's look at CoarseGrainedExecutorBackend.receive, which handles the LaunchTask message as follows:

case LaunchTask(data) =>
  if (executor == null) {
    exitExecutor(1, "Received LaunchTask command but executor was null")
  } else {
    val taskDesc = TaskDescription.decode(data.value)
    logInfo("Got assigned task " + taskDesc.taskId)
    executor.launchTask(this, taskDesc)
  }
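
As a side note, the information a TaskDescription carries can be pictured roughly as follows. This is only an illustrative sketch built from the fields used later in this post, not the exact definition of the real class:

// Illustrative sketch only: the real org.apache.spark.scheduler.TaskDescription has more detail.
case class TaskDescriptionSketch(
  taskId: Long,                        // the task id (TID) assigned by the scheduler
  attemptNumber: Int,                  // which attempt of this task this is
  addedFiles: Map[String, Long],       // files the executor must download, with timestamps
  addedJars: Map[String, Long],        // jars to download and add to the task class loader
  properties: java.util.Properties,    // local properties propagated from the driver
  serializedTask: java.nio.ByteBuffer  // the serialized Task object itself
)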

As we can see, the task is actually handed to the internal Executor object. The Executor object carries almost all of the executor-side logic; CoarseGrainedExecutorBackend can be thought of as merely relaying RPC messages, acting as an endpoint in Spark's RPC framework, while the actual task execution logic is implemented by the Executor object.

Executor overview

Let's first look at the class comment of Executor:

/**
 * Spark executor, backed by a threadpool to run tasks.
 *
 * This can be used with Mesos, YARN, and the standalone scheduler.
 * An internal RPC interface is used for communication with the driver,
 * except in the case of Mesos fine-grained mode.
 */

Executor runs tasks on an internal thread pool; Mesos, YARN, and standalone mode all use this class for task execution. In addition, the Executor object holds a reference to SparkEnv, through which it uses Spark's infrastructure, including RPC references.
As before, we walk through this class along the path of task execution.

Executor.launchTask

def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
  val tr = new TaskRunner(context, taskDescription)
  runningTasks.put(taskDescription.taskId, tr)
  threadPool.execute(tr)
}

The code itself is simple: it wraps the TaskDescription in a TaskRunner, records it in runningTasks, and submits it to the thread pool, so next we look at the TaskRunner class.
It also shows that on the executor side each task runs on its own thread; a rough sketch of such a task-launch pool follows below.
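
Purely as an illustration of this one-task-one-thread model, here is a minimal, self-contained sketch of a daemon cached thread pool in the spirit of the one used for task launch (Spark builds its pool through an internal ThreadUtils helper; the names below are made up):

import java.util.concurrent.{Executors, ExecutorService, ThreadFactory}
import java.util.concurrent.atomic.AtomicInteger

object TaskLaunchPoolSketch {
  private val counter = new AtomicInteger(0)

  private val factory = new ThreadFactory {
    override def newThread(r: Runnable): Thread = {
      val t = new Thread(r, s"task launch worker-${counter.getAndIncrement()}")
      t.setDaemon(true) // daemon threads, so the pool never blocks JVM shutdown
      t
    }
  }

  // A cached pool creates threads on demand and reuses idle ones,
  // so every concurrently running task effectively gets its own thread.
  val threadPool: ExecutorService = Executors.newCachedThreadPool(factory)
}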

TaskRunner.run

This method is very long and takes some patience to read through.
I will skip some of the bookkeeping, such as task run time, CPU time and GC time statistics. One thing worth taking away is the use of MXBeans: CPU and GC times are obtained from the JVM's built-in MXBeans, with ManagementFactory as the entry class (a small standalone illustration follows below); I won't expand on the details here.
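
As a small standalone illustration of that mechanism (this is the general JMX API, not Spark's exact code):

import java.lang.management.ManagementFactory
import scala.collection.JavaConverters._

object MxBeanSketch {
  // CPU time (in nanoseconds) consumed so far by the current thread.
  def currentThreadCpuTimeNanos(): Long = {
    val threadMXBean = ManagementFactory.getThreadMXBean
    // Thread CPU time may be unsupported on some JVMs, hence the guard.
    if (threadMXBean.isCurrentThreadCpuTimeSupported) threadMXBean.getCurrentThreadCpuTime else 0L
  }

  // Cumulative GC time (in milliseconds) across all garbage collectors.
  def totalGcTimeMillis(): Long =
    ManagementFactory.getGarbageCollectorMXBeans.asScala
      .map(_.getCollectionTime)
      .filter(_ >= 0) // a collector returns -1 if it does not report collection time
      .sum
}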

To summarize, the main steps of this method are:

  • First, send a task status update message to the driver, telling it that this task is now in the RUNNING state.
  • Set the task properties and update the dependent files and jars, adding the new jars to the class loader's search path; note that all of this information comes from the driver along with the TaskDescription.
  • Deserialize the task into a Task object, which, depending on the task type, is a ShuffleMapTask or a ResultTask.
  • Check whether the task has been killed, and throw an exception if so (the driver may send a kill-task message at any time).
  • Call Task.run to execute the actual task logic.
  • After the task finishes, clean up memory and block locks that were not released properly, logging a resource-leak warning or throwing an exception when required.
  • Check once more whether the task has been killed.
  • Serialize the task's result data.
  • Update some task statistics (accumulators) and the related counters in the metrics system.
  • Collect all accumulators associated with the task (both the built-in statistics accumulators and user-registered ones).
  • Wrap the accumulator data and the result data into one object and serialize it once more.
  • Check the serialized size against two thresholds, maxResultSize and maxDirectResultSize. If it exceeds maxResultSize, the result is simply dropped: nothing is written to the block manager, so when the driver later tries to fetch the data remotely through the block manager it finds nothing, and thus knows the result was too large and the task has failed. If the size exceeds maxDirectResultSize (but not maxResultSize), the result is written through the block manager to local memory and disk, and only the block information is sent to the driver, which then fetches the data from this node. If the size is below maxDirectResultSize, the result is sent to the driver directly over the RPC interface.
  • Finally, the various exceptions raised on task failure are handled.

    override def run(): Unit = {
    threadId = Thread.currentThread.getId
    Thread.currentThread.setName(threadName)
    // MXBean for monitoring this thread's execution
    val threadMXBean = ManagementFactory.getThreadMXBean
    // Per-task memory manager
    val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId)
    // Record deserialization time; recall that this value is shown in the Spark UI, and this is where it is measured
    val deserializeStartTime = System.currentTimeMillis()
    // CPU time spent on deserialization
    val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
    threadMXBean.getCurrentThreadCpuTime
    } else 0L
    Thread.currentThread.setContextClassLoader(replClassLoader)
    val ser = env.closureSerializer.newInstance()
    logInfo(s"Running $taskName (TID $taskId)")
    // Send a task status update message to the driver through the executor backend
    execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
    var taskStart: Long = 0
    var taskStartCpu: Long = 0
    // Total GC time, again obtained via an MXBean
    startGCTime = computeTotalGcTime()

    try {
      // Must be set before updateDependencies() is called, in case fetching dependencies
      // requires access to properties contained within (e.g. for access control).
      Executor.taskDeserializationProps.set(taskDescription.properties)
    
      // Update the dependent files and jars: fetch them from the driver to the local node and cache them
      updateDependencies(taskDescription.addedFiles, taskDescription.addedJars)
      // Deserialize the task into a Task object (no separate timing is recorded for this step)
      task = ser.deserialize[Task[Any]](
        taskDescription.serializedTask, Thread.currentThread.getContextClassLoader)
      // The property set is also shipped from the driver along with the TaskDescription
      task.localProperties = taskDescription.properties
      // Set the memory manager for this task
      task.setTaskMemoryManager(taskMemoryManager)
    
      // If this task has been killed before we deserialized it, let's quit now. Otherwise,
      // continue executing the task.
      // Check whether the task has been killed
      val killReason = reasonIfKilled
      if (killReason.isDefined) {
        // Throw an exception rather than returning, because returning within a try{} block
        // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
        // exception will be caught by the catch block, leading to an incorrect ExceptionFailure
        // for the task.
        throw new TaskKilledException(killReason.get)
      }
    
      // The purpose of updating the epoch here is to invalidate executor map output status cache
      // in case FetchFailures have occurred. In local mode `env.mapOutputTracker` will be
      // MapOutputTrackerMaster and its cache invalidation is not based on epoch numbers so
      // we don't need to make any special calls here.
      //
      if (!isLocal) {
        logDebug("Task " + taskId + "'s epoch is " + task.epoch)
        // Update the epoch to invalidate the cached map output statuses
        env.mapOutputTracker.asInstanceOf[MapOutputTrackerWorker].updateEpoch(task.epoch)
      }
    
      // Run the actual task and measure its runtime.
      taskStart = System.currentTimeMillis()
      // CPU time of the current thread at task start
      taskStartCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
        threadMXBean.getCurrentThreadCpuTime
      } else 0L
      var threwException = true
      val value = try {
        // Run the task by calling task.run
        val res = task.run(
          // the task id
          taskAttemptId = taskId,
          // the attempt number of this task
          attemptNumber = taskDescription.attemptNumber,
          // the metrics system
          metricsSystem = env.metricsSystem)
        threwException = false
        res
      } finally {
        // Release all locks held by this task, i.e. the read/write locks on the blocks it used
        val releasedLocks = env.blockManager.releaseAllLocksForTask(taskId)
        // Free all memory allocated to this task
        val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
    
        // If threwException is false, the task completed normally.
        // If memory can still be freed at this point even though the task completed normally,
        // the task did not properly release the memory it used, i.e. a memory leak occurred
        if (freedMemory > 0 && !threwException) {
          val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId"
          if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) {
            throw new SparkException(errMsg)
          } else {
            logWarning(errMsg)
          }
        }
    
        // The check for lock resources follows the same logic as the memory check above:
        // the Spark authors take the view that a task is responsible for releasing the resources
        // it acquired (memory as well as locks) once it is done with them,
        // rather than relying on this later clean-up as a safety net.
        // If they were not released properly, a resource leak occurred; here we check for leaked block locks.
        if (releasedLocks.nonEmpty && !threwException) {
          val errMsg =
            s"${releasedLocks.size} block locks were not released by TID = $taskId:\n" +
              releasedLocks.mkString("[", ", ", "]")
          if (conf.getBoolean("spark.storage.exceptionOnPinLeak", false)) {
            throw new SparkException(errMsg)
          } else {
            logInfo(errMsg)
          }
        }
      }
      // Log fetch failures that were swallowed by user code.
      // If execution reaches this point, user code did not re-throw a fetch failure,
      // yet the framework detected one, which means user code swallowed it.
      // That is clearly wrong, so an error is logged to warn the user.
      task.context.fetchFailed.foreach { fetchFailure =>
        // uh-oh.  it appears the user code has caught the fetch-failure without throwing any
        // other exceptions.  Its *possible* this is what the user meant to do (though highly
        // unlikely).  So we will log an error and keep going.
        logError(s"TID ${taskId} completed successfully though internally it encountered " +
          s"unrecoverable fetch failures!  Most likely this means user code is incorrectly " +
          s"swallowing Spark's internal ${classOf[FetchFailedException]}", fetchFailure)
      }
      // Task finish time
      val taskFinish = System.currentTimeMillis()
      // CPU time consumed by the task thread at finish
      val taskFinishCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
        threadMXBean.getCurrentThreadCpuTime
      } else 0L
    
      // If the task has been killed, let's fail it.
      // Check again whether the task has been killed
      task.context.killTaskIfInterrupted()
    
      // Serialize the task result
      val resultSer = env.serializer.newInstance()
      val beforeSerialization = System.currentTimeMillis()
      val valueBytes = resultSer.serialize(value)
      val afterSerialization = System.currentTimeMillis()
    
      // Deserialization happens in two parts: first, we deserialize a Task object, which
      // includes the Partition. Second, Task.run() deserializes the RDD and function to be run.
      task.metrics.setExecutorDeserializeTime(
        (taskStart - deserializeStartTime) + task.executorDeserializeTime)
      task.metrics.setExecutorDeserializeCpuTime(
        (taskStartCpu - deserializeStartCpuTime) + task.executorDeserializeCpuTime)
      // We need to subtract Task.run()'s deserialization time to avoid double-counting
      task.metrics.setExecutorRunTime((taskFinish - taskStart) - task.executorDeserializeTime)
      task.metrics.setExecutorCpuTime(
        (taskFinishCpu - taskStartCpu) - task.executorDeserializeCpuTime)
      task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime)
      task.metrics.setResultSerializationTime(afterSerialization - beforeSerialization)
    
      // Expose task metrics using the Dropwizard metrics system.
      // Update task metrics counters
      executorSource.METRIC_CPU_TIME.inc(task.metrics.executorCpuTime)
      executorSource.METRIC_RUN_TIME.inc(task.metrics.executorRunTime)
      executorSource.METRIC_JVM_GC_TIME.inc(task.metrics.jvmGCTime)
      executorSource.METRIC_DESERIALIZE_TIME.inc(task.metrics.executorDeserializeTime)
      executorSource.METRIC_DESERIALIZE_CPU_TIME.inc(task.metrics.executorDeserializeCpuTime)
      executorSource.METRIC_RESULT_SERIALIZE_TIME.inc(task.metrics.resultSerializationTime)
      executorSource.METRIC_SHUFFLE_FETCH_WAIT_TIME
        .inc(task.metrics.shuffleReadMetrics.fetchWaitTime)
      executorSource.METRIC_SHUFFLE_WRITE_TIME.inc(task.metrics.shuffleWriteMetrics.writeTime)
      executorSource.METRIC_SHUFFLE_TOTAL_BYTES_READ
        .inc(task.metrics.shuffleReadMetrics.totalBytesRead)
      executorSource.METRIC_SHUFFLE_REMOTE_BYTES_READ
        .inc(task.metrics.shuffleReadMetrics.remoteBytesRead)
      executorSource.METRIC_SHUFFLE_REMOTE_BYTES_READ_TO_DISK
        .inc(task.metrics.shuffleReadMetrics.remoteBytesReadToDisk)
      executorSource.METRIC_SHUFFLE_LOCAL_BYTES_READ
        .inc(task.metrics.shuffleReadMetrics.localBytesRead)
      executorSource.METRIC_SHUFFLE_RECORDS_READ
        .inc(task.metrics.shuffleReadMetrics.recordsRead)
      executorSource.METRIC_SHUFFLE_REMOTE_BLOCKS_FETCHED
        .inc(task.metrics.shuffleReadMetrics.remoteBlocksFetched)
      executorSource.METRIC_SHUFFLE_LOCAL_BLOCKS_FETCHED
        .inc(task.metrics.shuffleReadMetrics.localBlocksFetched)
      executorSource.METRIC_SHUFFLE_BYTES_WRITTEN
        .inc(task.metrics.shuffleWriteMetrics.bytesWritten)
      executorSource.METRIC_SHUFFLE_RECORDS_WRITTEN
        .inc(task.metrics.shuffleWriteMetrics.recordsWritten)
      executorSource.METRIC_INPUT_BYTES_READ
        .inc(task.metrics.inputMetrics.bytesRead)
      executorSource.METRIC_INPUT_RECORDS_READ
        .inc(task.metrics.inputMetrics.recordsRead)
      executorSource.METRIC_OUTPUT_BYTES_WRITTEN
        .inc(task.metrics.outputMetrics.bytesWritten)
      executorSource.METRIC_OUTPUT_RECORDS_WRITTEN
        .inc(task.metrics.inputMetrics.recordsRead)
      executorSource.METRIC_RESULT_SIZE.inc(task.metrics.resultSize)
      executorSource.METRIC_DISK_BYTES_SPILLED.inc(task.metrics.diskBytesSpilled)
      executorSource.METRIC_MEMORY_BYTES_SPILLED.inc(task.metrics.memoryBytesSpilled)
    
      // Note: accumulator updates must be collected after TaskMetrics is updated
      // Collect the accumulator updates here
      val accumUpdates = task.collectAccumulatorUpdates()
      // TODO: do not serialize value twice
      val directResult = new DirectTaskResult(valueBytes, accumUpdates)
      val serializedDirectResult = ser.serialize(directResult)
      val resultSize = serializedDirectResult.limit()
    
      // directSend = sending directly back to the driver
      val serializedResult: ByteBuffer = {
        if (maxResultSize > 0 && resultSize > maxResultSize) {
          logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " +
            s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +
            s"dropping it.")
          ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))
        } else if (resultSize > maxDirectResultSize) {
          val blockId = TaskResultBlockId(taskId)
          env.blockManager.putBytes(
            blockId,
            new ChunkedByteBuffer(serializedDirectResult.duplicate()),
            StorageLevel.MEMORY_AND_DISK_SER)
          logInfo(
            s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
          ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
        } else {
          logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")
          serializedDirectResult
        }
      }
    
      setTaskFinishedAndClearInterruptStatus()
      execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
    
    } catch {
      case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) =>
        val reason = task.context.fetchFailed.get.toTaskFailedReason
        if (!t.isInstanceOf[FetchFailedException]) {
          // there was a fetch failure in the task, but some user code wrapped that exception
          // and threw something else.  Regardless, we treat it as a fetch failure.
          val fetchFailedCls = classOf[FetchFailedException].getName
          logWarning(s"TID ${taskId} encountered a ${fetchFailedCls} and " +
            s"failed, but the ${fetchFailedCls} was hidden by another " +
            s"exception.  Spark is handling this like a fetch failure and ignoring the " +
            s"other exception: $t")
        }
        setTaskFinishedAndClearInterruptStatus()
        execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
    
      case t: TaskKilledException =>
        logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}")
        setTaskFinishedAndClearInterruptStatus()
        execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason)))
    
      case _: InterruptedException | NonFatal(_) if
          task != null && task.reasonIfKilled.isDefined =>
        val killReason = task.reasonIfKilled.getOrElse("unknown reason")
        logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason")
        setTaskFinishedAndClearInterruptStatus()
        execBackend.statusUpdate(
          taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason)))
    
      case CausedBy(cDE: CommitDeniedException) =>
        val reason = cDE.toTaskCommitDeniedReason
        setTaskFinishedAndClearInterruptStatus()
        execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason))
    
      case t: Throwable =>
        // Attempt to exit cleanly by informing the driver of our failure.
        // If anything goes wrong (or this was a fatal exception), we will delegate to
        // the default uncaught exception handler, which will terminate the Executor.
        logError(s"Exception in $taskName (TID $taskId)", t)
    
        // SPARK-20904: Do not report failure to driver if if happened during shut down. Because
        // libraries may set up shutdown hooks that race with running tasks during shutdown,
        // spurious failures may occur and can result in improper accounting in the driver (e.g.
        // the task failure would not be ignored if the shutdown happened because of premption,
        // instead of an app issue).
        if (!ShutdownHookManager.inShutdown()) {
          // Collect latest accumulator values to report back to the driver
          val accums: Seq[AccumulatorV2[_, _]] =
            if (task != null) {
              task.metrics.setExecutorRunTime(System.currentTimeMillis() - taskStart)
              task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime)
              task.collectAccumulatorUpdates(taskFailed = true)
            } else {
              Seq.empty
            }
    
          val accUpdates = accums.map(acc => acc.toInfo(Some(acc.value), None))
    
          val serializedTaskEndReason = {
            try {
              ser.serialize(new ExceptionFailure(t, accUpdates).withAccums(accums))
            } catch {
              case _: NotSerializableException =>
                // t is not serializable so just send the stacktrace
                ser.serialize(new ExceptionFailure(t, accUpdates, false).withAccums(accums))
            }
          }
          setTaskFinishedAndClearInterruptStatus()
          execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason)
        } else {
          logInfo("Not reporting error to driver during JVM shutdown.")
        }
    
        // Don't forcibly exit unless the exception was inherently fatal, to avoid
        // stopping other tasks unnecessarily.
        if (!t.isInstanceOf[SparkOutOfMemoryError] && Utils.isFatalError(t)) {
          uncaughtExceptionHandler.uncaughtException(Thread.currentThread(), t)
        }
    } finally {
      runningTasks.remove(taskId)
    }

    }

Task.run

final def run(
  taskAttemptId: Long,
  attemptNumber: Int,
  metricsSystem: MetricsSystem): T = {
SparkEnv.get.blockManager.registerTask(taskAttemptId)
context = new TaskContextImpl(
  stageId,
  stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal
  partitionId,
  taskAttemptId,
  attemptNumber,
  taskMemoryManager,
  localProperties,
  // The metrics system is the one held by SparkEnv
  metricsSystem,
  metrics)
TaskContext.setTaskContext(context)
// Record the thread that runs this task
taskThread = Thread.currentThread()

// If a kill reason has already been set, propagate it by updating the kill-reason flag in the TaskContext
// (no thread interrupt is sent here, since interruptThread = false)
if (_reasonIfKilled != null) {
  kill(interruptThread = false, _reasonIfKilled)
}

new CallerContext(
  "TASK",
  SparkEnv.get.conf.get(APP_CALLER_CONTEXT),
  appId,
  appAttemptId,
  jobId,
  Option(stageId),
  Option(stageAttemptId),
  Option(taskAttemptId),
  Option(attemptNumber)).setCurrentContext()

try {
  runTask(context)
} catch {
  case e: Throwable =>
    // Catch all errors; run task failure callbacks, and rethrow the exception.
    try {
      context.markTaskFailed(e)
    } catch {
      case t: Throwable =>
        e.addSuppressed(t)
    }
    context.markTaskCompleted(Some(e))
    throw e
} finally {
  try {
    // Call the task completion callbacks. If "markTaskCompleted" is called twice, the second
    // one is no-op.
    context.markTaskCompleted(None)
  } finally {
    try {
      Utils.tryLogNonFatalError {
        // Release memory used by this thread for unrolling blocks
        // Release the memory this task used in the memory store for unrolling blocks.
        // This ultimately goes through the memory manager and really just updates its internal
        // bookkeeping of memory usage; the memory itself is of course reclaimed by GC.
        SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP)
        SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(
          MemoryMode.OFF_HEAP)
        // Notify any tasks waiting for execution memory to be freed to wake up and try to
        // acquire memory again. This makes impossible the scenario where a task sleeps forever
        // because there are no other tasks left to notify it. Since this is safe to do but may
        // not be strictly necessary, we should revisit whether we can remove this in the
        // future.
        val memoryManager = SparkEnv.get.memoryManager
        // After memory has been released, notify other threads that are waiting for execution memory
        memoryManager.synchronized { memoryManager.notifyAll() }
      }
    } finally {
      // Though we unset the ThreadLocal here, the context member variable itself is still
      // queried directly in the TaskRunner to check for FetchFailedExceptions.
      TaskContext.unset()
    }
  }
}
}
  • Create a TaskContextImpl and set it into a ThreadLocal variable
  • Check whether the task has been killed
  • Call runTask to execute the actual task logic
  • Finally, release the memory acquired for unrolling data during shuffle

So what we need to analyze next is of course the runTask method, which is abstract in Task. ResultTask is very simple, so I won't go through it in detail (a rough sketch follows below); here I will focus on ShuffleMapTask.
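
For completeness, ResultTask.runTask essentially just deserializes the RDD and the user function from the broadcast task binary and applies the function to the partition's iterator. A simplified sketch, not the verbatim Spark source (the timing bookkeeping is omitted):

// Simplified sketch of ResultTask.runTask
override def runTask(context: TaskContext): U = {
  val ser = SparkEnv.get.closureSerializer.newInstance()
  // The task binary is a broadcast of the serialized (RDD, user function) pair.
  val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
    ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
  // Compute (or read from cache/checkpoint) this partition and apply the user function to it;
  // the return value becomes the task result that is sent back to the driver.
  func(context, rdd.iterator(partition, context))
}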

ShuffleMapTask.runTask

override def runTask(context: TaskContext): MapStatus = {
// Deserialize the RDD using the broadcast variable.
val threadMXBean = ManagementFactory.getThreadMXBean
val deserializeStartTime = System.currentTimeMillis()
val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
  threadMXBean.getCurrentThreadCpuTime
} else 0L
val ser = SparkEnv.get.closureSerializer.newInstance()
// Deserialize the RDD and the shuffle dependency; this is the key step.
// Something to think about: when the rdd and the shuffle dependency are deserialized, how is the SparkContext reference inside them handled?
val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
  ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
_executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
_executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
  threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
} else 0L

var writer: ShuffleWriter[Any, Any] = null
try {
  // The shuffle manager
  val manager = SparkEnv.get.shuffleManager
  // Obtain a shuffle writer for this map partition
  writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
  // Here we can see that the core method of RDD computation is the iterator method.
  // SortShuffleWriter.write can be broken down into a few steps:
  // the data computed by the upstream rdd (obtained by calling rdd.iterator) is written into an in-memory buffer;
  // whenever a memory threshold is exceeded during the write, the data is spilled to disk, possibly into several files;
  // finally, the spill files and the data remaining in memory are merge-sorted and written to disk as one large data file.
  // The sort order is first by partition, then by key.
  // During that final merged write, the output is flushed after each partition and the offset of that partition's
  // data within the file is recorded, so once a task finishes there are two files on disk:
  // the data file and an index file recording the offset of each reduce partition's data.
  writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
  // Mainly deletes the intermediate spill files and releases the acquired memory back to the memory manager
  writer.stop(success = true).get
} catch {
  case e: Exception =>
    try {
      if (writer != null) {
        writer.stop(success = false)
      }
    } catch {
      case e: Exception =>
        log.debug("Could not stop writer", e)
    }
    throw e
}
}

The overall logic of this method is fairly simple: it calls the rdd's iterator method to obtain the computed result of this task's partition (returned as an iterator) and writes that data, via the shuffle manager, into shuffle block files on disk; the resulting MapStatus, which records where the output lives, is then returned to the driver as the task result.
So there are really two important steps: obtaining the computed result through the RDD's compute chain, and writing that result, sorted and partitioned, into files.
Let me analyze the second step first; for the first step, the entry point RDD.iterator is sketched briefly below.
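
Roughly, RDD.iterator either serves the partition through the caching path or computes it directly (simplified from the Spark sources):

// Simplified sketch of RDD.iterator
final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
  if (storageLevel != StorageLevel.NONE) {
    // The RDD is marked for caching: try the block manager first, compute and cache on a miss.
    getOrCompute(split, context)
  } else {
    // Not cached: compute the partition, reading from a checkpoint if one exists.
    computeOrReadCheckpoint(split, context)
  }
}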

SortShuffleWriter.write

Since Spark 2.0 the shuffle manager is the sort-based shuffle manager, i.e. SortShuffleManager, so the writer obtained from it here is normally a SortShuffleWriter; when the bypass conditions are met (no map-side combine and the number of partitions does not exceed the bypass threshold, 200 by default), a BypassMergeSortShuffleWriter is used instead. The decision logic is sketched below.
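
The bypass decision itself is a small check, roughly (simplified from SortShuffleWriter.shouldBypassMergeSort):

// Roughly the bypass check: no map-side combine, and the number of reduce partitions
// must not exceed spark.shuffle.sort.bypassMergeThreshold (200 by default).
def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = {
  if (dep.mapSideCombine) {
    false // map-side aggregation needs the sorter, so bypass is impossible
  } else {
    val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
    dep.partitioner.numPartitions <= bypassMergeThreshold
  }
}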

override def write(records: Iterator[Product2[K, V]]): Unit = {
sorter = if (dep.mapSideCombine) {
  // Map-side combine: an Aggregator must have been specified (a key ordering may also be provided)
  require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
  new ExternalSorter[K, V, C](
    context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
} else {
  // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
  // care whether the keys get sorted in each partition; that will be done on the reduce side
  // if the operation being run is sortByKey.
  new ExternalSorter[K, V, V](
    context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
}
// Write all of the map output into the sorter;
// this may produce multiple spill files along the way
sorter.insertAll(records)

// Don't bother including the time to open the merged output file in the shuffle write time,
// because it just opens a single file, so is typically too fast to measure accurately
// (see SPARK-3570).
// mapId is simply the partitionId of the shuffle-map-side RDD
// Get the shuffle output data file for this map partition
val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
// Add a uuid suffix to obtain a temporary file
val tmp = Utils.tempFileWith(output)
try {
  val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
  // This step merge-sorts the data spilled to disk and the data remaining in memory,
  // and writes everything into a single file, under a temporary file name for now
  val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
  // This step writes the index file and atomically renames (moves) the temporary index and data files to their final names
  shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
  // Return a status object containing the shuffle server id and the size of each partition's data in the file
  mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
} finally {
  if (tmp.exists() && !tmp.delete()) {
    logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
  }
}
}

To summarize the main logic of this method:

  • First obtain a sorter, checking whether a map-side combiner is required
  • Write the rdd's computed output into the sorter; this may spill to multiple disk files along the way
  • Merge-sort the small spill files together with the data in the in-memory buffer and write everything into a single data file
  • Write the offset of each partition's data within that file into an index file, which the reduce stage uses when fetching data
  • Return a MapStatus object wrapping the id of this executor's blockManager and the size of each partition's data in the data file

Summary

That is as far as this post goes. The remaining code is the sorter's internal logic for sorting data and spilling files, which deserves an article of its own.
To summarize the task execution flow on the executor side:

  • First, the RPC endpoint on the executor receives the LaunchTask message and deserializes the incoming task data into a TaskDescription
  • The task is handed over to the Executor object for execution
  • Executor builds a TaskRunner from the TaskDescription and submits it to the thread pool. The pool is a cached thread pool (in the spirit of Executors.newCachedThreadPool), so no threads are kept running while the executor is idle
  • TaskRunner deserializes the task further and calls Task.run to execute the task logic
  • A ShuffleMapTask sorts and merges the rdd's computed output, writes it into a single data file, and writes an index file alongside it
  • After the task finishes, task statistics and the related metrics-system counters are updated
  • Finally, depending on the size of the serialized result, one of several ways is chosen to send the result back to the driver.