spark-streaming series ------- 3. How data is received with the Kafka DirectDStream approach

def createDirectStream[
    K: ClassTag,
    V: ClassTag,
    KD <: Decoder[K]: ClassTag,
    VD <: Decoder[V]: ClassTag] (
      ssc: StreamingContext,
      kafkaParams: Map[String, String],
      topics: Set[String]
  ): InputDStream[(K, V)] = {
    val messageHandler = (mmd: MessageAndMetadata[K, V]) => (mmd.key, mmd.message)
    val kc = new KafkaCluster(kafkaParams)
    val reset = kafkaParams.get("auto.offset.reset").map(_.toLowerCase)

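    val result = for {
      /**
       * Talk to the Kafka cluster to get the partition information of the requested topics. topicPartitions is a set
       * with one element per partition, so its size equals the total number of partitions of the topics.
       * Each element (a TopicAndPartition) holds the topic name and the partition index.
       */
      topicPartitions <- kc.getPartitions(topics).right
      leaderOffsets <- (if (reset == Some("smallest")) {
        kc.getEarliestLeaderOffsets(topicPartitions)
      } else {
        kc.getLatestLeaderOffsets(topicPartitions)
      }).right
    } yield {
      // Take the starting offset of each partition of the Kafka topic
      val fromOffsets = leaderOffsets.map { case (tp, lo) =>
          (tp, lo.offset)
      }
      new DirectKafkaInputDStream[K, V, KD, VD, (K, V)](
        ssc, kafkaParams, fromOffsets, messageHandler)
    }
    KafkaCluster.checkErrors(result)
  }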

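Here, Spark talks to the Kafka cluster to obtain the message offset of every partition of the Kafka topic (the earliest offsets when auto.offset.reset is "smallest", otherwise the latest ones) and passes them as the fromOffsets argument when it creates the DirectKafkaInputDStream.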
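The offset lookup above can be modeled without a Kafka cluster. This is a minimal sketch, assuming Scala 2.12+ (where Either is right-biased, so `.right` is not needed); TopicPartitionId, FromOffsetsSketch and the three lookup functions are hypothetical stand-ins for TopicAndPartition and the KafkaCluster calls, not Spark's API. It shows how the for-comprehension picks earliest or latest offsets depending on auto.offset.reset and stops at the first error.

```scala
// Hypothetical stand-in for Kafka's TopicAndPartition.
final case class TopicPartitionId(topic: String, partition: Int)

object FromOffsetsSketch {
  type Err = String
  type Offsets = Map[TopicPartitionId, Long]

  /** Mirrors createDirectStream: earliest offsets for "smallest", latest offsets otherwise. */
  def fromOffsets(
      kafkaParams: Map[String, String],
      topics: Set[String],
      getPartitions: Set[String] => Either[Err, Set[TopicPartitionId]],
      earliest: Set[TopicPartitionId] => Either[Err, Offsets],
      latest: Set[TopicPartitionId] => Either[Err, Offsets]): Either[Err, Offsets] = {
    val reset = kafkaParams.get("auto.offset.reset").map(_.toLowerCase)
    // The first Left (e.g. a failed partition lookup) short-circuits the whole chain.
    for {
      partitions <- getPartitions(topics)
      offsets <- if (reset == Some("smallest")) earliest(partitions) else latest(partitions)
    } yield offsets
  }
}
```

Passing plain functions instead of a real KafkaCluster keeps the example self-contained; the decision logic is the same one the excerpt performs inside createDirectStream.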


class DirectKafkaInputDStream[
  K: ClassTag,
  V: ClassTag,
  U <: Decoder[K]: ClassTag,
  T <: Decoder[V]: ClassTag,
  R: ClassTag](
    @transient ssc_ : StreamingContext,
    val kafkaParams: Map[String, String],
    val fromOffsets: Map[TopicAndPartition, Long],
    messageHandler: MessageAndMetadata[K, V] => R
) extends InputDStream[R](ssc_) with Logging {
  val maxRetries = context.sparkContext.getConf.getInt(
    "spark.streaming.kafka.maxRetries", 1)

  // Keep this consistent with how other streams are named (e.g. "Flume polling stream [2]")
  private[streaming] override def name: String = s"Kafka direct stream [$id]"

  protected[streaming] override val checkpointData =
    new DirectKafkaInputDStreamCheckpointData

  protected val kc = new KafkaCluster(kafkaParams)

  protected val maxMessagesPerPartition: Option[Long] = {
    val ratePerSec = context.sparkContext.getConf.getInt(
      "spark.streaming.kafka.maxRatePerPartition", 0)
    if (ratePerSec > 0) {
      val secsPerBatch = context.graph.batchDuration.milliseconds.toDouble / 1000
      Some((secsPerBatch * ratePerSec).toLong)
    } else {
      None
    }
  }

  protected var currentOffsets = fromOffsets

  @tailrec
  protected final def latestLeaderOffsets(retries: Int): Map[TopicAndPartition, LeaderOffset] = {
    val o = kc.getLatestLeaderOffsets(currentOffsets.keySet)
    // Either.fold would confuse @tailrec, do it manually
    if (o.isLeft) {
      val err = o.left.get.toString
      if (retries <= 0) {
        throw new SparkException(err)
      } else {
        latestLeaderOffsets(retries - 1)
      }
    } else {
      o.right.get
    }
  }

  // limits the maximum number of messages per partition
  /**
   * When no maximum receive rate is set, each batch ends at the latest offset of every partition at the current time.
   */
  protected def clamp(
    leaderOffsets: Map[TopicAndPartition, LeaderOffset]): Map[TopicAndPartition, LeaderOffset] = {
    maxMessagesPerPartition.map { mmp =>
      leaderOffsets.map { case (tp, lo) =>
        tp -> lo.copy(offset = Math.min(currentOffsets(tp) + mmp, lo.offset))
      }
    }.getOrElse(leaderOffsets)
  }

  override def compute(validTime: Time): Option[KafkaRDD[K, V, U, T, R]] = {
    val untilOffsets = clamp(latestLeaderOffsets(maxRetries))
    val rdd = KafkaRDD[K, V, U, T, R](
      context.sparkContext, kafkaParams, currentOffsets, untilOffsets, messageHandler)

    // Report the record number of this batch interval to InputInfoTracker.
    val inputInfo = InputInfo(id, rdd.count)
    ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo)

    currentOffsets = untilOffsets.map(kv => kv._1 -> kv._2.offset)
    Some(rdd)
  }
  // ... (other members omitted)
}

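Conclusion: with the DirectDStream approach, if no maximum receive rate (spark.streaming.kafka.maxRatePerPartition) is set, each batch reads every partition from where the previous batch stopped up to the partition's latest offset at the moment compute runs, so a batch holds roughly the messages the Kafka topic received during one batch interval. When the rate is set, each partition contributes at most (batch duration in seconds × maxRatePerPartition) messages per batch.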
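A small sketch of the per-batch end-offset computation, assuming plain Map[String, Long] entries keyed by hypothetical names such as "t-0" in place of TopicAndPartition, and Long offsets in place of LeaderOffset. It follows the same formulas as maxMessagesPerPartition and clamp in the excerpt above; ClampSketch is an illustrative name, not a Spark class.

```scala
object ClampSketch {
  /** Same formula as maxMessagesPerPartition: rate * batch seconds, or None when no rate is set. */
  def maxMessagesPerPartition(ratePerSec: Int, batchDurationMs: Long): Option[Long] =
    if (ratePerSec > 0) {
      val secsPerBatch = batchDurationMs.toDouble / 1000
      Some((secsPerBatch * ratePerSec).toLong)
    } else None

  /** Same idea as clamp: end at min(current + cap, latest), or at latest when uncapped. */
  def clamp(current: Map[String, Long], latest: Map[String, Long], cap: Option[Long]): Map[String, Long] =
    cap.map { mmp =>
      latest.map { case (tp, lo) => tp -> math.min(current(tp) + mmp, lo) }
    }.getOrElse(latest)
}
```

For example, with a 2-second batch and a rate of 100 messages per second, a partition sitting at offset 1000 whose latest offset is 5000 is read only up to 1200 in this batch; without a rate limit it would be read all the way to 5000.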

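The Kafka partition information reaches DirectKafkaInputDStream at construction time through the fromOffsets parameter, which initializes its currentOffsets member. When a KafkaRDD is created, currentOffsets is passed to it as the starting offsets, and this partition information is what KafkaRDD uses to generate its partitions.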
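To illustrate how currentOffsets drives consecutive batches, here is a hypothetical model (OffsetRangeSketch and BatchLoopSketch are illustrative names, not Spark classes). Each step turns [currentOffsets, untilOffsets) into one range per partition and then advances currentOffsets to untilOffsets, so consecutive batches cover each partition without gaps or overlap. The ranges are sorted by topic and partition only to make the output deterministic.

```scala
// Hypothetical stand-in for Spark's OffsetRange: [fromOffset, untilOffset) of one topic partition.
final case class OffsetRangeSketch(topic: String, partition: Int, fromOffset: Long, untilOffset: Long) {
  def count: Long = untilOffset - fromOffset
}

object BatchLoopSketch {
  type TP = (String, Int) // (topic, partition)

  /** One compute() step: ranges [current, until) per partition, then current advances to until. */
  def step(current: Map[TP, Long], until: Map[TP, Long]): (Seq[OffsetRangeSketch], Map[TP, Long]) = {
    val ranges = current.toSeq.sortBy(_._1).map { case ((topic, partition), from) =>
      OffsetRangeSketch(topic, partition, from, until((topic, partition)))
    }
    (ranges, until)
  }
}
```

Feeding the returned offsets into the next step reproduces the batch-to-batch hand-off: the second batch starts exactly where the first one ended, and a partition with no new messages yields an empty range.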

object KafkaRDD {
  import KafkaCluster.LeaderOffset

  /**
   * @param kafkaParams Kafka <a href="http://kafka.apache.org/documentation.html#configuration">
   * configuration parameters</a>.
   *   Requires "metadata.broker.list" or "bootstrap.servers" to be set with Kafka broker(s),
   *   NOT zookeeper servers, specified in host1:port1,host2:port2 form.
   * @param fromOffsets per-topic/partition Kafka offsets defining the (inclusive)
   *  starting point of the batch
   * @param untilOffsets per-topic/partition Kafka offsets defining the (exclusive)
   *  ending point of the batch
   * @param messageHandler function for translating each message into the desired type
   */
  def apply[
    K: ClassTag,
    V: ClassTag,
    U <: Decoder[_]: ClassTag,
    T <: Decoder[_]: ClassTag,
    R: ClassTag](
      sc: SparkContext,
      kafkaParams: Map[String, String],
      fromOffsets: Map[TopicAndPartition, Long],
      untilOffsets: Map[TopicAndPartition, LeaderOffset],
      messageHandler: MessageAndMetadata[K, V] => R
    ): KafkaRDD[K, V, U, T, R] = {
    val leaders = untilOffsets.map { case (tp, lo) =>
        tp -> (lo.host, lo.port)
    }.toMap

    // From the start and end offsets of each Kafka topic partition, build the OffsetRange describing the data to fetch
    val offsetRanges = fromOffsets.map { case (tp, fo) =>
        val uo = untilOffsets(tp)
        OffsetRange(tp.topic, tp.partition, fo, uo.offset)
    }.toArray

    new KafkaRDD[K, V, U, T, R](sc, kafkaParams, offsetRanges, leaders, messageHandler)
  }
}

class KafkaRDD[
  K: ClassTag,
  V: ClassTag,
  U <: Decoder[_]: ClassTag,
  T <: Decoder[_]: ClassTag,
  R: ClassTag] private[spark] (
    sc: SparkContext,
    kafkaParams: Map[String, String],
    val offsetRanges: Array[OffsetRange],
    leaders: Map[TopicAndPartition, (String, Int)],
    messageHandler: MessageAndMetadata[K, V] => R
  ) extends RDD[R](sc, Nil) with Logging with HasOffsetRanges {
  override def getPartitions: Array[Partition] = {
    offsetRanges.zipWithIndex.map { case (o, i) =>
        val (host, port) = leaders(TopicAndPartition(o.topic, o.partition)) // host is the address of the partition's leader Kafka broker, port is that broker's port
        new KafkaRDDPartition(i, o.topic, o.partition, o.fromOffset, o.untilOffset, host, port)
    }.toArray
  }
  // ... (other members omitted)
}

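When the RDD is created, getPartitions is eventually called. This fixes the IP address and port of each KafkaRDD partition: every partition is located at the address of the Kafka broker that leads the corresponding topic partition. As explained in the previous article of this series, DirectKafkaInputDStream.compute is called periodically by the Spark Streaming scheduler to produce the DStream's RDD for each batch.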
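The partition-building step can be modeled as follows. This is a hypothetical simplification: PartitionSketch stands in for KafkaRDDPartition, ranges are plain (topic, partition, from, until) tuples, and preferredLocations returns the leader host, assuming (as described above) that the scheduler treats the broker host as the partition's location.

```scala
// Hypothetical stand-in for KafkaRDDPartition.
final case class PartitionSketch(index: Int, topic: String, partition: Int,
    fromOffset: Long, untilOffset: Long, host: String, port: Int) {
  // The location the scheduler would try to run this partition's task on.
  def preferredLocations: Seq[String] = Seq(host)
}

object GetPartitionsSketch {
  /** Like getPartitions: one partition per offset range, indexed in order, located at its leader broker. */
  def partitions(ranges: Seq[(String, Int, Long, Long)],
                 leaders: Map[(String, Int), (String, Int)]): Seq[PartitionSketch] =
    ranges.zipWithIndex.map { case ((topic, p, from, until), i) =>
      val (host, port) = leaders((topic, p))
      PartitionSketch(i, topic, p, from, until, host, port)
    }
}
```

The RDD partition index is just the position in the range list, while the topic partition number and the broker address travel along with it.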



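After the Spark Streaming job starts, SparkContext.runJob is called to submit the data receiving and processing tasks to Spark's task scheduling system. After walking the chain of RDD dependencies, the scheduler finds that the root RDD is the KafkaRDD. Based on the KafkaRDD's partitions, it first adds the tasks that process the KafkaRDD to the pending-task maps. This is implemented in TaskSetManager.addPendingTask: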

private def addPendingTask(index: Int, readding: Boolean = false) {
    // Utility method that adds `index` to a list only if readding=false or it's not already there
    def addTo(list: ArrayBuffer[Int]) {
      if (!readding || !list.contains(index)) {
        list += index
      }
    }

    for (loc <- tasks(index).preferredLocations) { // preferredLocations returns the address of the host where the partition lives
      loc match {
        case e: ExecutorCacheTaskLocation =>
          addTo(pendingTasksForExecutor.getOrElseUpdate(e.executorId, new ArrayBuffer))
        case e: HDFSCacheTaskLocation => {
          val exe = sched.getExecutorsAliveOnHost(loc.host)
          exe match {
            case Some(set) => {
              for (e <- set) {
                addTo(pendingTasksForExecutor.getOrElseUpdate(e, new ArrayBuffer))
              }
              logInfo(s"Pending task $index has a cached location at ${e.host} " +
                ", where there are executors " + set.mkString(","))
            }
            case None => logDebug(s"Pending task $index has a cached location at ${e.host} " +
                ", but there are no executors alive there.")
          }
        }
        case _ => Unit
      }
      // In DirectDStream mode loc.host is a Kafka broker, which belongs neither to the Spark cluster nor to the HDFS cluster,
      // so the task is only added to this host map (and to the rack map below, if the host's rack is known)
      addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer))
      for (rack <- sched.getRackForHost(loc.host)) {
        addTo(pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer))
      }
    }

    if (tasks(index).preferredLocations == Nil) {
      addTo(pendingTasksWithNoPrefs)
    }

    if (!readding) {
      // Every task is added to this list, including the tasks created in DirectDStream mode
      allPendingTasks += index  // No point scanning this whole list to find the old task there
    }
  }
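A stripped-down model of the bookkeeping in addPendingTask, assuming only host preferences (no executor or HDFS cache locations, no racks). PendingTasksSketch is a hypothetical class, not Spark's TaskSetManager; it shows that a task preferring a broker host lands in the per-host map and, like every task, in allPendingTasks, while a host that no task prefers (such as an executor's worker host) never gets an entry.

```scala
import scala.collection.mutable.{ArrayBuffer, HashMap}

// Hypothetical, simplified model of TaskSetManager's pending-task collections.
final class PendingTasksSketch(preferredHosts: IndexedSeq[Seq[String]]) {
  val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]]
  val pendingTasksWithNoPrefs = new ArrayBuffer[Int]
  val allPendingTasks = new ArrayBuffer[Int]

  def addPendingTask(index: Int): Unit = {
    val hosts = preferredHosts(index)
    // Register the task under every host it prefers (for a KafkaRDD task: the leader broker's host).
    for (host <- hosts) {
      pendingTasksForHost.getOrElseUpdate(host, new ArrayBuffer[Int]) += index
    }
    // Tasks without any preference go to the no-preference list.
    if (hosts.isEmpty) pendingTasksWithNoPrefs += index
    // Every task, with or without preferences, is also appended here.
    allPendingTasks += index
  }
}
```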




Executors then offer their free cores, and TaskSchedulerImpl.resourceOfferSingleTaskSet asks the task set for tasks, one offer at a time:

  private def resourceOfferSingleTaskSet(
      taskSet: TaskSetManager,
      maxLocality: TaskLocality,
      shuffledOffers: Seq[WorkerOffer],
      availableCpus: Array[Int],
      tasks: Seq[ArrayBuffer[TaskDescription]]) : Boolean = {
    var launchedTask = false
    for (i <- 0 until shuffledOffers.size) {
      val execId = shuffledOffers(i).executorId
      val host = shuffledOffers(i).host
      if (availableCpus(i) >= CPUS_PER_TASK) { // tasks are handed out according to the number of free CPU cores
        try {
          for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
            tasks(i) += task // assign the task to the i-th worker offer (the offers have already been shuffled)
            val tid = task.taskId
            taskIdToTaskSetId(tid) = taskSet.taskSet.id // record the task set the task belongs to
            taskIdToExecutorId(tid) = execId // record the executor that runs the task
            executorsByHost(host) += execId
            availableCpus(i) -= CPUS_PER_TASK
            assert(availableCpus(i) >= 0)
            launchedTask = true
          }
        } catch {
          case e: TaskNotSerializableException =>
            logError(s"Resource offer failed, task set ${taskSet.name} was not serializable")
            // Do not offer resources for this task, but don't throw an error to allow other
            // task sets to be submitted.
            return launchedTask
        }
      }
    }
    return launchedTask
  }
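The core-based assignment loop can be sketched like this, assuming the offers are already shuffled and that the caller keeps calling resourceOfferSingleTaskSet while it launches at least one task (TaskSchedulerImpl.resourceOffers does this in a do/while loop in Spark 1.x). OfferSketch and OfferRoundsSketch are hypothetical; the sketch ignores locality and simply hands out task indices in order, at most one per offer per round, as resourceOffer returns at most one task per call.

```scala
import scala.collection.mutable

// Hypothetical model of one executor's offer of free cores.
final case class OfferSketch(executorId: String, host: String, cores: Int)

object OfferRoundsSketch {
  def assign(offers: IndexedSeq[OfferSketch], numTasks: Int, cpusPerTask: Int): Map[String, List[Int]] = {
    val availableCpus = offers.map(_.cores).toArray
    val assigned = Array.fill(offers.size)(mutable.ArrayBuffer.empty[Int])
    val pending = mutable.Queue.empty[Int] ++= (0 until numTasks)
    var launchedTask = true
    // Repeat rounds until a full pass over the offers launches nothing.
    while (launchedTask) {
      launchedTask = false
      for (i <- offers.indices) {
        // Only an offer with enough free cores may take a task in this round.
        if (availableCpus(i) >= cpusPerTask && pending.nonEmpty) {
          assigned(i) += pending.dequeue()
          availableCpus(i) -= cpusPerTask
          launchedTask = true
        }
      }
    }
    offers.indices.map(i => offers(i).executorId -> assigned(i).toList).toMap
  }
}
```

With one executor offering 2 cores and another offering 1, four single-core tasks are spread round-robin until the cores run out: the first executor gets tasks 0 and 2, the second gets task 1, and task 3 waits for the next offer.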




For each offer, TaskSetManager.resourceOffer picks a task for the given executor and host, never searching beyond the allowed locality level:

def resourceOffer(
      execId: String,
      host: String,
      maxLocality: TaskLocality.TaskLocality)
    : Option[TaskDescription] =
  {
    if (!isZombie) {
      val curTime = clock.getTimeMillis()

      var allowedLocality = maxLocality

      if (maxLocality != TaskLocality.NO_PREF) {
        allowedLocality = getAllowedLocalityLevel(curTime)
        if (allowedLocality > maxLocality) {
          // We're not allowed to search for farther-away tasks
          allowedLocality = maxLocality
        }
      }

      dequeueTask(execId, host, allowedLocality) match {
        case Some((index, taskLocality, speculative)) => {
          // Found a task; do some bookkeeping and return a task description
          val task = tasks(index)
          val taskId = sched.newTaskId()
          // Do various bookkeeping
          copiesRunning(index) += 1
          val attemptNum = taskAttempts(index).size
          val info = new TaskInfo(taskId, index, attemptNum, curTime,
            execId, host, taskLocality, speculative)
          taskInfos(taskId) = info
          taskAttempts(index) = info :: taskAttempts(index)
          // Update our locality level for delay scheduling
          // NO_PREF will not affect the variables related to delay scheduling
          if (maxLocality != TaskLocality.NO_PREF) {
            currentLocalityIndex = getLocalityIndex(taskLocality)
            lastLaunchTime = curTime
          }
          // Serialize and return the task
          val startTime = clock.getTimeMillis()
          val serializedTask: ByteBuffer = try {
            Task.serializeWithDependencies(task, sched.sc.addedFiles, sched.sc.addedJars, ser)
          } catch {
            // If the task cannot be serialized, then there's no point to re-attempt the task,
            // as it will always fail. So just abort the whole task-set.
            case NonFatal(e) =>
              val msg = s"Failed to serialize task $taskId, not attempting to retry it."
              logError(msg, e)
              abort(s"$msg Exception during serialization: $e")
              throw new TaskNotSerializableException(e)
          }
          if (serializedTask.limit > TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024 &&
              !emittedTaskSizeWarning) {
            emittedTaskSizeWarning = true
            logWarning(s"Stage ${task.stageId} contains a task of very large size " +
              s"(${serializedTask.limit / 1024} KB). The maximum recommended task size is " +
              s"${TaskSetManager.TASK_SIZE_TO_WARN_KB} KB.")
          }

          // We used to log the time it takes to serialize the task, but task size is already
          // a good proxy to task serialization time.
          // val timeTaken = clock.getTime() - startTime
          val taskName = s"task ${info.id} in stage ${taskSet.id}"
          logInfo("Starting %s (TID %d, %s, %s, %d bytes)".format(
              taskName, taskId, host, taskLocality, serializedTask.limit))

          sched.dagScheduler.taskStarted(task, info)
          return Some(new TaskDescription(taskId = taskId, attemptNumber = attemptNum, execId,
            taskName, index, serializedTask))
        }
        case _ =>
      }
    }
    None
  }
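The locality clamp at the top of resourceOffer can be isolated in a few lines. TaskLocalitySketch is a hypothetical re-creation of the TaskLocality ordering (PROCESS_LOCAL < NODE_LOCAL < NO_PREF < RACK_LOCAL < ANY), and the delayAllowed parameter stands in for the result of getAllowedLocalityLevel; neither is Spark's own code.

```scala
// Hypothetical re-creation of the locality levels, ordered from most to least local.
object TaskLocalitySketch extends Enumeration {
  val PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY = Value

  // `condition` is allowed if it is no farther away than `constraint`.
  def isAllowed(constraint: Value, condition: Value): Boolean = condition <= constraint

  /** The allowedLocality step of resourceOffer; delayAllowed plays the role of getAllowedLocalityLevel. */
  def allowedLocality(maxLocality: Value, delayAllowed: Value): Value =
    if (maxLocality == NO_PREF) maxLocality
    else if (delayAllowed > maxLocality) maxLocality // never search farther than maxLocality
    else delayAllowed
}
```

The result is always the more restrictive of the two levels, except for NO_PREF, which bypasses delay scheduling entirely.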

The lookup itself is done by TaskSetManager.dequeueTask, which tries the pending-task lists from the most local level to the least local one:

  private def dequeueTask(execId: String, host: String, maxLocality: TaskLocality.Value)
    : Option[(Int, TaskLocality.Value, Boolean)] =
  {
    for (index <- dequeueTaskFromList(execId, getPendingTasksForExecutor(execId))) {
      return Some((index, TaskLocality.PROCESS_LOCAL, false))
    }

    if (TaskLocality.isAllowed(maxLocality, TaskLocality.NODE_LOCAL)) { // the KafkaRDD partition's host differs from the executor's host, so the task cannot be taken from this map
      for (index <- dequeueTaskFromList(execId, getPendingTasksForHost(host))) {
        return Some((index, TaskLocality.NODE_LOCAL, false))
      }
    }

    if (TaskLocality.isAllowed(maxLocality, TaskLocality.NO_PREF)) {
      // Look for noPref tasks after NODE_LOCAL for minimize cross-rack traffic
      for (index <- dequeueTaskFromList(execId, pendingTasksWithNoPrefs)) {
        return Some((index, TaskLocality.PROCESS_LOCAL, false))
      }
    }

    if (TaskLocality.isAllowed(maxLocality, TaskLocality.RACK_LOCAL)) {
      for {
        rack <- sched.getRackForHost(host)
        index <- dequeueTaskFromList(execId, getPendingTasksForRack(rack))
      } {
        return Some((index, TaskLocality.RACK_LOCAL, false))
      }
    }

    if (TaskLocality.isAllowed(maxLocality, TaskLocality.ANY)) { // the KafkaRDD processing tasks are taken from the allPendingTasks list
      for (index <- dequeueTaskFromList(execId, allPendingTasks)) {
        return Some((index, TaskLocality.ANY, false))
      }
    }

    // find a speculative task if all others tasks have been scheduled
    dequeueSpeculativeTask(execId, host, maxLocality).map {
      case (taskIndex, allowedLocality) => (taskIndex, allowedLocality, true)}
  }

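When handing out tasks, the scheduler tries the tasks with the best locality first. Because the IP address of each KafkaRDD partition (a Kafka broker) differs from the IP addresses of the Spark executors, assuming brokers and executors run on separate machines, the tasks can only be taken from the allPendingTasks list, at the ANY locality level (or from the rack list, if rack information is available and a broker shares a rack with the executor).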
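The fallback described above can be simulated with a simplified dequeue. DequeueSketch is hypothetical and models only two levels, NODE_LOCAL and ANY; like dequeueTaskFromList in the excerpt, it pops indices from the tail of a list and skips tasks that were already launched through another list.

```scala
import scala.collection.mutable

object DequeueSketch {
  sealed trait Level
  case object NodeLocal extends Level
  case object AnyLevel extends Level

  /** Hypothetical pending-task state built from each task's preferred hosts. */
  final class Pending(preferredHosts: IndexedSeq[Seq[String]]) {
    private val forHost = mutable.HashMap.empty[String, mutable.ArrayBuffer[Int]]
    private val all = mutable.ArrayBuffer.empty[Int]
    private val launched = mutable.Set.empty[Int]

    for (i <- preferredHosts.indices) {
      preferredHosts(i).foreach(h => forHost.getOrElseUpdate(h, mutable.ArrayBuffer.empty[Int]) += i)
      all += i
    }

    // Pops from the tail and skips tasks already launched through another list.
    private def pop(list: mutable.ArrayBuffer[Int]): Option[Int] = {
      while (list.nonEmpty) {
        val index = list.remove(list.size - 1)
        if (!launched(index)) {
          launched += index
          return Some(index)
        }
      }
      None
    }

    /** NODE_LOCAL first (tasks preferring this executor's host), then ANY. */
    def dequeue(host: String): Option[(Int, Level)] =
      pop(forHost.getOrElse(host, mutable.ArrayBuffer.empty[Int])).map(i => (i, NodeLocal: Level))
        .orElse(pop(all).map(i => (i, AnyLevel: Level)))
  }
}
```

An executor on a worker host receives every broker-preferring task at the ANY level, whereas an executor that happened to run on broker1 would get the task preferring broker1 as NODE_LOCAL before falling back to ANY.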



On the executor, the task's runTask method then does the actual work. For a ShuffleMapTask it looks like this:

override def runTask(context: TaskContext): MapStatus = {
    // Deserialize the RDD using the broadcast variable.
    val deserializeStartTime = System.currentTimeMillis()
    val ser = SparkEnv.get.closureSerializer.newInstance()
    val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
      ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
    _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime

    metrics = Some(context.taskMetrics)
    var writer: ShuffleWriter[Any, Any] = null
    try {
      val manager = SparkEnv.get.shuffleManager
      writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
      // rdd.iterator reads and processes the data and hands the processed records to the writer
      writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
      return writer.stop(success = true).get
    } catch {
      case e: Exception =>
        try {
          if (writer != null) {
            writer.stop(success = false)
          }
        } catch {
          case e: Exception =>
            log.debug("Could not stop writer", e)
        }
        throw e
    }
  }

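Following the RDD dependencies, this method ends up calling KafkaRDD.compute. Because KafkaRDD is the root RDD, its compute is the innermost call in the chain of dependent RDDs and the first one to produce data: it returns an Iterator over the messages fetched from the Kafka broker, which the downstream RDDs then consume. When Spark processes an RDD partition, an Iterator is the most basic form in which the partition's data is organized.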
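The iterator chaining can be illustrated with plain Scala iterators. This is a hypothetical analogy, not Spark code: the root iterator stands in for KafkaRDD.compute, and map/filter stand in for downstream narrow RDDs. Because Scala iterators are lazy, nothing is fetched until the consumer (here toList; in Spark, the shuffle writer or result handler) pulls records, and each record flows through the whole chain before the next one is fetched.

```scala
import scala.collection.mutable.ArrayBuffer

object IteratorChainSketch {
  /** Builds the chain lazily; nothing is fetched until the returned iterator is consumed. */
  def build(messages: Seq[String], log: ArrayBuffer[String]): Iterator[Int] = {
    val root = messages.iterator.map { m => log += s"fetch $m"; m } // stands in for KafkaRDD.compute
    val mapped = root.map { m => log += s"map $m"; m.length }      // a downstream map
    mapped.filter(_ > 1)                                           // a downstream filter
  }
}
```

Building the chain logs nothing; consuming it interleaves "fetch" and "map" per message, which is why the root RDD's compute effectively runs first for every record.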

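Conclusion: when Spark Streaming receives data with DirectDStream, there is no separate receiver. The tasks of each batch's KafkaRDD fetch their offset ranges directly from the Kafka brokers, and the fetched data is organized directly into RDD partitions for processing.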