在 Spark mapPartitions 中使用迭代器進行優化
阿新 • 發佈:2021-12-13
一般在使用 mapPartitions 時,往往會伴隨著檔案的建立或資料庫的連線等操作。此時常見的做法是再建立一個容器,用於儲存維表關聯後的資料;但這樣有一個問題:容器會把整個分割槽的結果都保存在記憶體中,分割槽資料量大時會佔用大量記憶體。這時我們可以使用迭代器進行優化,讓每筆資料在被下游取用時才進行關聯並回傳,而不必事先全部暫存。
一、普遍方法
package org.shydow

import java.sql.{Connection, PreparedStatement, ResultSet}

import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}
import org.shydow.DBPool.MySQLPoolManager

import scala.collection.mutable.ListBuffer

/**
 * @author shydow
 * @date 2021-12-13
 * @desc Typical mapPartitions usage: open one DB connection per partition, look up each
 *       record's event name in a dimension table, and buffer all joined records in a
 *       ListBuffer before handing them back. The buffer holds the whole partition's output
 *       in memory at once, which is the cost the iterator-based version avoids.
 */
object TestMapPartition {

  case class Event(eventId: String, eventName: String, pv: Long, stayTime: String)

  def main(args: Array[String]): Unit = {
    val conf: SparkConf = new SparkConf().setAppName("test-mapPartition").setMaster("local[*]")
    val sc = new SparkContext(conf)
    val lines: RDD[String] = sc.textFile("/app/event_log.txt", 4)

    val events: RDD[Event] = lines.mapPartitions { it =>
      val list: ListBuffer[Event] = ListBuffer[Event]()
      val conn: Connection = MySQLPoolManager.getMySQLManager.getConnection
      try {
        val ps: PreparedStatement =
          conn.prepareStatement("select event_name from dim_event_info where event_id = ?")
        try {
          while (it.hasNext) {
            val line: String = it.next()
            // Fields read below: 0 = event id, 2 = pv, 3 = stay time (field 1 is not used).
            val arr: Array[String] = line.split(",")
            ps.setString(1, arr(0))
            val res: ResultSet = ps.executeQuery()
            val eventName: String =
              try {
                var name: String = null
                // Keeps the last matching row; stays null when the id has no dimension entry.
                while (res.next()) name = res.getString("event_name")
                name
              } finally res.close()
            list.append(Event(arr(0), eventName, arr(2).toLong, arr(3)))
          }
        } finally ps.close()
      } finally {
        // Release the connection even if a record fails to parse or a query throws.
        conn.close()
      }
      list.iterator
    }

    // Spark transformations are lazy: without an action the mapPartitions body never runs.
    events.foreach(println)

    sc.stop()
  }
}
二、使用迭代器進行優化
package org.shydow

import java.sql.{Connection, PreparedStatement, ResultSet}

import org.apache.spark.rdd.RDD
import org.apache.spark.util.TaskCompletionListener
import org.apache.spark.{SparkConf, SparkContext, TaskContext}
import org.shydow.DBPool.MySQLPoolManager

import scala.util.control.NonFatal

/**
 * @author shydow
 * @date 2021-12-13
 * @desc Demonstrates using an iterator inside mapPartitions for dimension-table lookups,
 *       so joined records are produced one at a time instead of being buffered per partition.
 */
object TestMapPartition {
  def main(args: Array[String]): Unit = {
    val conf: SparkConf = new SparkConf().setAppName("test-mapPartition").setMaster("local[*]")
    val sc = new SparkContext(conf)
    val lines: RDD[String] = sc.textFile("/app/event_log.txt", 4)
    val events: RDD[Event] = lines.mapPartitions(new LookupEventIter(_))
    // Spark transformations are lazy: without an action the lookup iterator is never driven.
    events.foreach(println)
    sc.stop()
  }
}

case class Event(eventId: String, eventName: String, pv: Long, stayTime: String)

/**
 * Wraps a partition's raw lines and, on each `next()`, parses one line and looks up its
 * event name through a single PreparedStatement shared by the whole partition.
 *
 * The connection and statement are released when the underlying iterator is exhausted, and
 * additionally when the Spark task completes. That covers partitions that are only partially
 * consumed (e.g. `take`/`first`) or whose task fails midway, which would otherwise leak the
 * connection.
 *
 * @param it the partition's raw comma-separated lines
 */
class LookupEventIter(it: Iterator[String]) extends Iterator[Event] {
  private val conn: Connection = MySQLPoolManager.getMySQLManager.getConnection
  private val ps: PreparedStatement =
    try conn.prepareStatement("select event_name from dim_event_info where event_id = ?")
    catch {
      case NonFatal(e) =>
        // Do not leak the connection if the statement cannot be prepared.
        conn.close()
        throw e
    }

  // Guards against releasing the resources more than once (exhaustion + task completion).
  private var closed: Boolean = false

  // TaskContext.get() is null outside a running task (e.g. when used on the driver).
  Option(TaskContext.get()).foreach { ctx =>
    ctx.addTaskCompletionListener(new TaskCompletionListener {
      override def onTaskCompletion(context: TaskContext): Unit = closeResources()
    })
  }

  override def hasNext: Boolean = {
    val more = !closed && it.hasNext
    if (!more) closeResources()
    more
  }

  override def next(): Event = {
    val line: String = it.next()
    // Fields read below: 0 = event id, 2 = pv, 3 = stay time (field 1 is not used).
    val arr: Array[String] = line.split(",")
    ps.setString(1, arr(0))
    val res: ResultSet = ps.executeQuery()
    val eventName: String =
      try {
        var name: String = null
        // Keeps the last matching row; stays null when the id has no dimension entry.
        while (res.next()) name = res.getString("event_name")
        name
      } finally res.close()
    Event(arr(0), eventName, arr(2).toLong, arr(3))
  }

  /** Closes the statement and the connection exactly once; the connection is closed even if the statement close fails. */
  private def closeResources(): Unit =
    if (!closed) {
      closed = true
      try ps.close()
      finally conn.close()
    }
}
三、資料庫連線池,使用 c3p0
package org.shydow.DBPool

import java.sql.Connection

import scala.util.control.NonFatal

/**
 * @author shydow
 * @date 2021-10-09
 * @desc Wrapper around a c3p0 ComboPooledDataSource whose settings come from `Constants`.
 */
class MySQLPool extends Serializable {
  // NOTE(review): ComboPooledDataSource is c3p0's class, but this snippet shows no import
  // for it — confirm `import com.mchange.v2.c3p0.ComboPooledDataSource` exists in the real file.
  private val cpd = new ComboPooledDataSource(true)

  // Configuration failures propagate, with the partially built pool closed, instead of being
  // printed and ignored. Ignoring them left a misconfigured pool that failed later in confusing ways.
  try {
    cpd.setJdbcUrl(Constants.MYSQL_URL)
    cpd.setDriverClass(Constants.MYSQL_DRIVER)
    cpd.setUser(Constants.MYSQL_USER)
    cpd.setPassword(Constants.MYSQL_PASSWORD)
    cpd.setAcquireIncrement(Constants.MYSQL_AC)
    cpd.setMinPoolSize(Constants.MYSQL_MINPS)
    cpd.setMaxPoolSize(Constants.MYSQL_MAXPS)
    cpd.setMaxStatements(Constants.MYSQL_MAXST)
  } catch {
    case NonFatal(e) =>
      try cpd.close() catch { case NonFatal(_) => () }
      throw e
  }

  /**
   * Borrows a connection from the pool.
   *
   * Failures are thrown rather than turned into `null`. Callers dereference the result
   * immediately, so `null` only replaced the real cause with a NullPointerException.
   *
   * @return a connection from the underlying data source
   */
  def getConnection: Connection = cpd.getConnection()

  /** Shuts the pool down; best-effort, so failures are only reported. */
  def close(): Unit =
    try cpd.close()
    catch { case NonFatal(e) => e.printStackTrace() }
}
package org.shydow.DBPool

/**
 * @author shydow
 * @date 2021-10-09
 * @desc Holder that lazily creates one shared MySQLPool on first request.
 */
object MySQLPoolManager {
  // Cached pool; remains null until getMySQLManager is called for the first time.
  var mm: MySQLPool = _

  /** Returns the shared pool, constructing it under this object's monitor if it does not exist yet. */
  def getMySQLManager: MySQLPool = synchronized {
    if (mm eq null) mm = new MySQLPool
    mm
  }
}