Broadcast Variables in Spark
How Spark Handles Broadcast
First, let's look at an example of code that uses a broadcast variable:
val values = List[Int](1, 2, 3)
val broadcastValues = sparkContext.broadcast(values)
rdd.mapPartitions { iter =>
  broadcastValues.value.foreach(println)
  iter
}
In the code above, we first create a collection, broadcast it through SparkContext's broadcast function, and finally read the broadcast variable while iterating over each partition of the RDD.
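Putting this together, here is a minimal, self-contained sketch of the pattern (the app name, master URL, and data are illustrative, not from the walkthrough):

import org.apache.spark.{SparkConf, SparkContext}

object BroadcastDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("broadcast-demo").setMaster("local[2]"))
    // Broadcast a small lookup collection once per executor instead of
    // shipping it inside every task closure.
    val broadcastValues = sc.broadcast(List(1, 2, 3))
    val matched = sc.parallelize(1 to 10, numSlices = 2).mapPartitions { iter =>
      // The first access to .value on an executor triggers the lazy fetch described below.
      val lookup = broadcastValues.value.toSet
      iter.filter(lookup.contains)
    }.count()
    println(s"matched = $matched") // matched = 3
    sc.stop()
  }
}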
Next, let's look at how a broadcast variable is created and how its data is read.
The entry point is SparkContext.broadcast. Note that you cannot broadcast an RDD directly: Spark only logs a warning here, rather than throwing an exception, to avoid breaking user programs that might have created RDD broadcast variables without using them. The actual broadcast is delegated to the newBroadcast function of the BroadcastManager.

def broadcast[T: ClassTag](value: T): Broadcast[T] = {
  assertNotStopped()
  if (classOf[RDD[_]].isAssignableFrom(classTag[T].runtimeClass)) {
    // This is a warning instead of an exception in order to avoid breaking
    // user programs that might have created RDD broadcast variables but not used them:
    ...
  }
  val bc = env.broadcastManager.newBroadcast[T](value, isLocal)
  val callSite = getCallSite
  logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm)
  ...
}
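Because RDDs cannot be broadcast directly, the usual workaround (the one the warning comment alludes to) is to collect a small RDD on the driver and broadcast the resulting local collection. A sketch, where smallRdd and bigRdd are hypothetical and smallRdd is assumed to fit comfortably in driver memory:

// sc.broadcast(smallRdd) would only trigger the warning above and ship a useless handle.
// Materialize the data on the driver first, then broadcast the plain collection:
val lookup = sc.broadcast(smallRdd.collectAsMap()) // e.g. RDD[(String, Int)] -> Map[String, Int]
val joined = bigRdd.map(key => (key, lookup.value.getOrElse(key, 0))) // map-side join, no shuffle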
BroadcastManager's newBroadcast simply delegates to the corresponding function of its broadcastFactory. The factory instance is selected by the spark.broadcast.factory setting; the default is TorrentBroadcastFactory.
def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = {
  broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
}
TorrentBroadcastFactory's version of newBroadcast, in turn, simply constructs a TorrentBroadcast instance:
override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] = {
  new TorrentBroadcast[T](value_, id)
}
When a TorrentBroadcast instance is constructed, the main work is writing out the value to be broadcast; the result is the number of blocks the value occupies. The block size is configured by spark.broadcast.blockSize and defaults to 4MB. Whether broadcast data is compressed is configured by spark.broadcast.compress, which defaults to true; snappy is the default codec.
private val broadcastId = BroadcastBlockId(id)

/** Total number of blocks this broadcast variable contains. */
private val numBlocks: Int = writeBlocks(obj)
The instance also defines a lazy value, which is evaluated only when the broadcast variable is actually used, not at construction time (in our example, on the first access to broadcastValues.value):
@transient private lazy val _value: T = readBroadcastBlock()

override protected def getValue() = {
  _value
}
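All three settings mentioned so far can be adjusted through SparkConf before the SparkContext is created. A sketch with illustrative values; it is an assumption of this sketch that, in the Spark version discussed here, spark.broadcast.blockSize is interpreted in kilobytes:

val conf = new SparkConf()
  // Broadcast implementation; TorrentBroadcastFactory is already the default.
  .set("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
  // Piece size for TorrentBroadcast: 8192 KB = 8MB instead of the 4MB default.
  .set("spark.broadcast.blockSize", "8192")
  // Compress the serialized pieces (snappy by default).
  .set("spark.broadcast.compress", "true")
val sc = new SparkContext(conf)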
Let's look at the writeBlocks function invoked during construction:
writeBlocks first stores a copy of the broadcast value in the local BlockManager, so that a task that uses this broadcast variable and happens to run on the same node can read it locally. It then serializes, compresses, and splits the value into blocks of blockSize bytes each, and stores every block in the BlockManager, reporting each one to the master so other executors can locate it. The return value is the number of blocks the serialized value was split into.

private def writeBlocks(value: T): Int = {
  // Keep a local copy so tasks on this node can read the value directly from the BlockManager.
  SparkEnv.get.blockManager.putSingle(broadcastId, value,
    StorageLevel.MEMORY_AND_DISK, tellMaster = false)
  // Serialize, compress, and split the value into blocks of at most blockSize bytes each.
  val blocks = TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer,
    compressionCodec)
  // Store every block and report it to the master so remote executors can fetch it.
  blocks.zipWithIndex.foreach { case (block, i) =>
    SparkEnv.get.blockManager.putBytes(
      BroadcastBlockId(id, "piece" + i),
      block,
      StorageLevel.MEMORY_AND_DISK_SER,
      tellMaster = true)
  }
  // Return the number of blocks the serialized value was split into.
  blocks.length
}
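To make the chunking step concrete, here is a simplified, hypothetical stand-in for TorrentBroadcast.blockifyObject that splits an already-serialized byte array into fixed-size pieces (the real implementation streams through the serializer and compression codec instead of materializing one large array):

import java.nio.ByteBuffer

def blockify(serialized: Array[Byte], blockSize: Int): Array[ByteBuffer] =
  serialized.grouped(blockSize)            // at most blockSize bytes per piece
    .map(piece => ByteBuffer.wrap(piece))  // each piece becomes one "piece" block
    .toArray

// A 10MB payload with the default 4MB block size yields 3 pieces,
// matching the numBlocks value that writeBlocks returns.
val pieces = blockify(Array.fill(10 * 1024 * 1024)(1.toByte), 4 * 1024 * 1024)
assert(pieces.length == 3)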
In our example, the first access to broadcastValues.value goes through getValue and triggers the lazy readBroadcastBlock defined above:
readBroadcastBlock first tries to read the assembled value for this broadcast variable from the local BlockManager. A hit means either that this executor has already read the variable, or that the task is running on the node where the broadcast was created. On a miss, this is the first read on this executor: readBlocks fetches all of the variable's blocks, the blocks are deserialized, and the resulting object is stored in the local BlockManager so that subsequent reads are local.

private def readBroadcastBlock(): T = Utils.tryOrIOException {
  TorrentBroadcast.synchronized {
    setConf(SparkEnv.get.conf)
    SparkEnv.get.blockManager.getLocal(broadcastId).map(_.data.next()) match {
      case Some(x) =>
        x.asInstanceOf[T]
      case None =>
        logInfo("Started reading broadcast variable " + id)
        val startTimeMs = System.currentTimeMillis()
        val blocks = readBlocks()
        logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs))
        val obj = TorrentBroadcast.unBlockifyObject[T](
          blocks, SparkEnv.get.serializer, compressionCodec)
        // Store the merged copy in BlockManager so other tasks on this executor don't
        // need to re-fetch it.
        SparkEnv.get.blockManager.putSingle(
          broadcastId, obj, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
        obj
    }
  }
}
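The control flow above is essentially a read-through cache keyed by the broadcast id. A stripped-down sketch of the same idea, with a mutable map standing in for the local BlockManager (all names here are hypothetical):

import scala.collection.mutable

object BroadcastCacheSketch {
  private val localStore = mutable.Map.empty[Long, Any] // stand-in for the local BlockManager

  // fetchAndAssemble stands in for readBlocks() + unBlockifyObject: the expensive first-read path.
  def readValue[T](id: Long)(fetchAndAssemble: => T): T = synchronized {
    localStore.get(id) match {
      case Some(v) => v.asInstanceOf[T] // already present: serve the local copy
      case None =>
        val v = fetchAndAssemble        // first read on this node: fetch all pieces and rebuild
        localStore(id) = v              // cache it so later tasks on this node read locally
        v
    }
  }
}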
Finally, let's look at how the readBlocks function works:
readBlocks fills an array of numBlocks buffers, where numBlocks is the number of blocks the serialized broadcast variable occupies. The blocks are fetched in random order; each block is first looked up in the local BlockManager and, on a miss, fetched from a remote BlockManager using the block's registered locations. Every remotely fetched block is stored in the local BlockManager and reported to the master, so this executor in turn becomes a source for other executors. The function returns the collection of fetched blocks.

private def readBlocks(): Array[ByteBuffer] = {
  // Holds the fetched blocks; numBlocks is the number of blocks the serialized value occupies.
  val blocks = new Array[ByteBuffer](numBlocks)
  val bm = SparkEnv.get.blockManager
  for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
    val pieceId = BroadcastBlockId(id, "piece" + pid)
    logDebug(s"Reading piece $pieceId of $broadcastId")
    // Try the local BlockManager first.
    def getLocal: Option[ByteBuffer] = bm.getLocalBytes(pieceId)
    // On a local miss, fetch the piece remotely and store it locally, telling the master
    // so this executor becomes another source for the piece.
    def getRemote: Option[ByteBuffer] = bm.getRemoteBytes(pieceId).map { block =>
      SparkEnv.get.blockManager.putBytes(
        pieceId,
        block,
        StorageLevel.MEMORY_AND_DISK_SER,
        tellMaster = true)
      block
    }
    val block: ByteBuffer = getLocal.orElse(getRemote).getOrElse(
      throw new SparkException(s"Failed to get $pieceId of $broadcastId"))
    blocks(pid) = block
  }
  blocks
}
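The Random.shuffle over the piece indices is what gives TorrentBroadcast its name: because concurrent executors request the pieces in different orders, and every remotely fetched piece is immediately re-announced to the master (tellMaster = true), executors quickly begin serving pieces to one another instead of all queuing on the original source. A generic sketch of the per-piece local-then-remote fallback, with getLocalPiece/getRemotePiece as hypothetical stand-ins for the BlockManager calls:

import java.nio.ByteBuffer
import scala.util.Random

// Hypothetical stand-ins for bm.getLocalBytes / bm.getRemoteBytes.
def getLocalPiece(pid: Int): Option[ByteBuffer] = None
def getRemotePiece(pid: Int): Option[ByteBuffer] = Some(ByteBuffer.allocate(16))

def fetchAllPieces(numBlocks: Int): Array[ByteBuffer] = {
  val blocks = new Array[ByteBuffer](numBlocks)
  // Shuffle the indices so concurrent readers don't all ask for piece 0 first.
  for (pid <- Random.shuffle((0 until numBlocks).toList)) {
    blocks(pid) = getLocalPiece(pid)
      .orElse(getRemotePiece(pid)) // local miss: fetch from a remote holder
      .getOrElse(throw new RuntimeException(s"failed to fetch piece $pid"))
  }
  blocks
}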