
[spark streaming] State Management: updateStateByKey & mapWithState

Preface

Spark Streaming runs 24/7 without interruption, and sometimes it needs to maintain state across batches. Take word counting: the data in each batch are not independent but must be accumulated, so Spark Streaming has to maintain state for us. There are currently two approaches, updateStateByKey and mapWithState. mapWithState is a state-management API added in Spark 1.6, which officially claims up to a 10x performance improvement.

updateStateByKey

Let's start with an example:

def updateFunction(currValues: Seq[Int], preValue: Option[Int]): Option[Int] = {
  val currValueSum = currValues.sum
  // The Int types above can be replaced with any object type
  Some(currValueSum + preValue.getOrElse(0)) // sum of current values plus the historical value
}

kafkaStream.map(r => (r._2, 1)).updateStateByKey(updateFunction _)
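
One practical note: updateStateByKey requires a checkpoint directory to be set, or the job fails at startup. A minimal sketch of the surrounding setup, assuming a placeholder checkpoint path and an arbitrary batch interval:

import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

val conf = new SparkConf().setAppName("StatefulWordCount")
val ssc = new StreamingContext(conf, Seconds(10)) // batch interval is arbitrary
// Stateful transformations persist the state RDD lineage via checkpoints
ssc.checkpoint("/tmp/streaming-checkpoint") // placeholder path

// ... build kafkaStream here, then:
// kafkaStream.map(r => (r._2, 1)).updateStateByKey(updateFunction _).print()

ssc.start()
ssc.awaitTermination()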

The updateFunction method here is the state-update logic we implement ourselves: currValues holds all the values for a key in the current batch, and preValue is the previously maintained state. updateStateByKey returns a DStream containing the full historical state. Now let's look at how state management is implemented underneath. Tracing the source code leads to the core implementation method:

  private [this] def computeUsingPreviousRDD(
      batchTime: Time, parentRDD: RDD[(K, V)], prevStateRDD: RDD[(K, S)]) = {
    // Define the function for the mapPartition operation on cogrouped RDD;
    // first map the cogrouped tuple to tuples of required type,
    // and then apply the update function
    val updateFuncLocal = updateFunc
    val finalFunc = (iterator: Iterator[(K, (Iterable[V], Iterable[S]))]) => {
      val i = iterator.map { t =>
        val itr = t._2._2.iterator
        val headOption = if (itr.hasNext) Some(itr.next()) else None
        (t._1, t._2._1.toSeq, headOption)
      }
      updateFuncLocal(batchTime, i)
    }
    val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner)
    val stateRDD = cogroupedRDD.mapPartitions(finalFunc, preservePartitioning)
    Some(stateRDD)
  }

As you can see, parentRDD and prevStateRDD are co-grouped, and finalFunc is then applied to each partition. Inside finalFunc, the tuple (t._1, t._2._1.toSeq, headOption) is exactly the (key, currValues, preValue) shape of the updateFunction we implement ourselves. Indeed, our function has simply been wrapped once:

def updateStateByKey[S: ClassTag](
      updateFunc: (Seq[V], Option[S]) => Option[S],
      partitioner: Partitioner
    ): DStream[(K, S)] = ssc.withScope {
    val cleanedUpdateF = sparkContext.clean(updateFunc)
    val newUpdateFunc = (iterator: Iterator[(K, Seq[V], Option[S])]) => {
      iterator.flatMap(t => cleanedUpdateF(t._2, t._3).map(s => (t._1, s)))
    }
    updateStateByKey(newUpdateFunc, partitioner, true)
  }

So every call to updateStateByKey co-groups the old state RDD with the current batch's RDD to produce a new state RDD. Even if only a single record actually needs updating, the two RDDs are still cogrouped and every record is recomputed; and as the state keeps growing, execution becomes slower and slower.
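
To make the cost concrete, here is a standalone toy sketch (plain RDDs, no streaming) of what this cogroup amounts to; note that keys "b" and "c" flow through the update logic even though the new batch never mentions them:

import org.apache.spark.{SparkConf, SparkContext}

val sc = new SparkContext(new SparkConf().setAppName("CogroupDemo").setMaster("local[2]"))

// State accumulated from previous batches
val prevStateRDD = sc.parallelize(Seq(("a", 10), ("b", 5), ("c", 7)))
// The current batch touches a single key
val parentRDD = sc.parallelize(Seq(("a", 1)))

// cogroup materializes every key from BOTH sides, so the update
// logic runs over the full state, not just the new data
val updated = parentRDD.cogroup(prevStateRDD).map { case (key, (currValues, prevState)) =>
  (key, currValues.sum + prevState.headOption.getOrElse(0))
}
updated.collect().foreach(println) // (a,11), (b,5), (c,7), in some order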

To solve this problem, mapWithState came into being.

mapWithState

Again, an example first:

val initialRDD = ssc.sparkContext.parallelize(List[(String, Int)]())
// Custom mappingFunc: accumulate each word's count and update the state
val mappingFunc = (word: String, count: Option[Int], state: State[Int]) => {
  val sum = count.getOrElse(0) + state.getOption.getOrElse(0)
  val output = (word, sum)
  state.update(sum)
  output
}
// Call mapWithState to manage the state of the stream
kafkaStream.map(r => (r._2, 1)).mapWithState(StateSpec.function(mappingFunc).initialState(initialRDD)).print()
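
Besides initialState, the StateSpec builder exposes a few more knobs worth knowing about; a sketch with arbitrary values:

import org.apache.spark.streaming.{Minutes, StateSpec}

val spec = StateSpec
  .function(mappingFunc)     // the user-defined update logic
  .initialState(initialRDD)  // seed the state from an existing RDD
  .numPartitions(16)         // controls the partitioner of the state RDD (example value)
  .timeout(Minutes(30))      // evict keys idle longer than this (example value)

kafkaStream.map(r => (r._2, 1)).mapWithState(spec)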

The initialRDD here is the initial state (updateStateByKey has a corresponding API as well). mappingFunc is again the state-update logic we implement ourselves: calling state.update() updates the state, and output is the shape of the data in the DStream returned by mapWithState. Note that mappingFunc is not passed in directly; instead we pass a StateSpec object, which is really just a wrapper around the function. Next, let's trace the source code to see how state management is implemented. First, a MapWithStateDStreamImpl instance is created:

def mapWithState[StateType: ClassTag, MappedType: ClassTag](
      spec: StateSpec[K, V, StateType, MappedType]
    ): MapWithStateDStream[K, V, StateType, MappedType] = {
    new MapWithStateDStreamImpl[K, V, StateType, MappedType](
      self,
      spec.asInstanceOf[StateSpecImpl[K, V, StateType, MappedType]]
    )
  }

Naturally, we want to see how its compute method is implemented:

 private val internalStream =
    new InternalMapWithStateDStream[KeyType, ValueType, StateType, MappedType](dataStream, spec)

  override def compute(validTime: Time): Option[RDD[MappedType]] = {
    internalStream.getOrCompute(validTime).map { _.flatMap[MappedType] { _.mappedData } }
  }

The compute method hands the processing off to internalStream, an InternalMapWithStateDStream. Continuing into the main logic of InternalMapWithStateDStream's compute method:

override def compute(validTime: Time): Option[RDD[MapWithStateRDDRecord[K, S, E]]] = {
    // Get the previous state or create a new empty state RDD
    val prevStateRDD = getOrCompute(validTime - slideDuration) match {
      case Some(rdd) =>
        if (rdd.partitioner != Some(partitioner)) {
          // If the RDD is not partitioned the right way, let us repartition it using the
          // partition index as the key. This is to ensure that state RDD is always partitioned
          // before creating another state RDD using it
          MapWithStateRDD.createFromRDD[K, V, S, E](
            rdd.flatMap { _.stateMap.getAll() }, partitioner, validTime)
        } else {
          rdd
        }
      case None =>
        MapWithStateRDD.createFromPairRDD[K, V, S, E](
          spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)),
          partitioner,
          validTime
        )
    }

    // Compute the new state RDD with previous state RDD and partitioned data RDD
    // Even if there is no data RDD, use an empty one to create a new state RDD
    val dataRDD = parent.getOrCompute(validTime).getOrElse {
      context.sparkContext.emptyRDD[(K, V)]
    }
    val partitionedDataRDD = dataRDD.partitionBy(partitioner)
    val timeoutThresholdTime = spec.getTimeoutInterval().map { interval =>
      (validTime - interval).milliseconds
    }
    Some(new MapWithStateRDD(
      prevStateRDD, partitionedDataRDD, mappingFunction, validTime, timeoutThresholdTime))
  }

It first obtains prevStateRDD and the parent data RDD, making sure both use the same partitioner, and then builds a MapWithStateRDD from the two RDDs, the user-defined mappingFunction, the key timeout threshold, and so on. This RDD extends RDD[MapWithStateRDDRecord[K, S, E]]: the elements of a MapWithStateRDD are MapWithStateRDDRecord objects, one per partition, each holding that partition's state (which is why the two RDDs must share the same partitioner). Let's look at MapWithStateRDD's compute method:

 override def compute(
      partition: Partition, context: TaskContext): Iterator[MapWithStateRDDRecord[K, S, E]] = {

    val stateRDDPartition = partition.asInstanceOf[MapWithStateRDDPartition]
    val prevStateRDDIterator = prevStateRDD.iterator(
      stateRDDPartition.previousSessionRDDPartition, context)
    val dataIterator = partitionedDataRDD.iterator(
      stateRDDPartition.partitionedDataRDDPartition, context)

    val prevRecord = if (prevStateRDDIterator.hasNext) Some(prevStateRDDIterator.next()) else None
    val newRecord = MapWithStateRDDRecord.updateRecordWithData(
      prevRecord,
      dataIterator,
      mappingFunction,
      batchTime,
      timeoutThresholdTime,
      removeTimedoutData = doFullScan // remove timedout data only when full scan is enabled
    )
    Iterator(newRecord)
  }
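
Before walking through it, it helps to know the shape of a MapWithStateRDDRecord; roughly, simplified from the Spark source:

// One record per partition: the partition's entire key -> state map,
// plus the values emitted by the mapping function in the current batch
private[streaming] case class MapWithStateRDDRecord[K, S, E](
    var stateMap: StateMap[K, S],
    var mappedData: Seq[E])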

It obtains iterators over the corresponding partitions of prevStateRDD and the partitioned data RDD, then pulls the single record out of prevStateRDD's partition: each partition holds exactly one MapWithStateRDDRecord, which maintains all the state for that partition. The core update method is then invoked, and an iterator containing just the one new record is returned. Let's see how this core computation works:

 def updateRecordWithData[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
    prevRecord: Option[MapWithStateRDDRecord[K, S, E]],
    dataIterator: Iterator[(K, V)],
    mappingFunction: (Time, K, Option[V], State[S]) => Option[E],
    batchTime: Time,
    timeoutThresholdTime: Option[Long],
    removeTimedoutData: Boolean
  ): MapWithStateRDDRecord[K, S, E] = {
    // Create a new state map by cloning the previous one (if it exists) or by creating an empty one
    val newStateMap = prevRecord.map { _.stateMap.copy() }. getOrElse { new EmptyStateMap[K, S]() }

    val mappedData = new ArrayBuffer[E]
    val wrappedState = new StateImpl[S]()

    // Call the mapping function on each record in the data iterator, and accordingly
    // update the states touched, and collect the data returned by the mapping function
    dataIterator.foreach { case (key, value) =>
      wrappedState.wrap(newStateMap.get(key))
      val returned = mappingFunction(batchTime, key, Some(value), wrappedState)
      if (wrappedState.isRemoved) {
        newStateMap.remove(key)
      } else if (wrappedState.isUpdated
          || (wrappedState.exists && timeoutThresholdTime.isDefined)) {
        newStateMap.put(key, wrappedState.get(), batchTime.milliseconds)
      }
      mappedData ++= returned
    }

    // Get the timed out state records, call the mapping function on each and collect the
    // data returned
    if (removeTimedoutData && timeoutThresholdTime.isDefined) {
      newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) =>
        wrappedState.wrapTimingOutState(state)
        val returned = mappingFunction(batchTime, key, None, wrappedState)
        mappedData ++= returned
        newStateMap.remove(key)
      }
    }

    MapWithStateRDDRecord(newStateMap, mappedData)
  }

First the previous state map is copied; then two variables are defined: mappedData is the result that will eventually be returned, and wrappedState can be seen as a wrapper around the state that adds some extra methods.

Next, the current batch's data is traversed: for each key, the old state is fetched from newStateMap and updated by the user-defined function. Depending on whether the state was removed, updated, or is subject to a timeout, newStateMap is updated accordingly, and each value returned by the mapping function is appended to mappedData.

If a timeout is configured, timed-out keys are also removed: the mapping function is invoked once more for each of them, and those results are appended to mappedData as well. Finally, a MapWithStateRDDRecord is built from the new state object newStateMap and the mappedData buffer and returned.
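
A mapping function that wants to react to these evictions can check State.isTimingOut; a hedged sketch (the three-element output tuple is just for illustration):

import org.apache.spark.streaming.State

// Emits (word, count, expired?) so downstream can observe evictions
val timeoutAwareFunc = (word: String, count: Option[Int], state: State[Int]) => {
  if (state.isTimingOut()) {
    // The key exceeded the StateSpec timeout: count is None here, and the
    // state is removed automatically after this call returns
    (word, state.get(), true)
  } else {
    val sum = count.getOrElse(0) + state.getOption.getOrElse(0)
    state.update(sum)
    (word, sum, false)
  }
}

It would be installed via something like StateSpec.function(timeoutAwareFunc).timeout(Minutes(30)).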

Back in the compute method of the MapWithStateDStreamImpl instance mentioned earlier:

  override def compute(validTime: Time): Option[RDD[MappedType]] = {
    internalStream.getOrCompute(validTime).map { _.flatMap[MappedType] { _.mappedData } }
  }

it is precisely this mappedData that gets flat-mapped out.

Notice that only records with updates are returned. To obtain the full state, call stateSnapshots after mapWithState; to clear the state of a particular key, call state.remove() inside your custom function.
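
For example, assuming the mappingFunc from the earlier snippet:

val stateDStream = kafkaStream.map(r => (r._2, 1))
  .mapWithState(StateSpec.function(mappingFunc))
// Only the keys updated in the current batch
stateDStream.print()
// A full key -> state snapshot, including keys untouched this batch
stateDStream.stateSnapshots().print()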

Summary

  • Under the hood, updateStateByKey co-groups prevStateRDD and parentRDD and then runs the user-defined update function over all of the data; even a batch with a single record triggers this full computation, which badly hurts performance, and the computation time keeps growing as the maintained state grows.
  • Under the hood, mapWithState creates a MapWithStateRDD whose elements are MapWithStateRDDRecord objects, one per partition, each recording all of that partition's state; each batch updates only the keys that actually appear in the batch, instead of recomputing everything the way updateStateByKey does.