Spark Database API Operations

Writing to a MySQL database

import java.sql.{Connection, Date, DriverManager, PreparedStatement}
import org.apache.spark.{SparkConf, SparkContext}

object IPLocation {

  // Write one partition's (province, count) pairs into MySQL
  val data2MySQL = (iterator: Iterator[(String, Int)]) => {

    var conn: Connection = null
    var ps: PreparedStatement = null
    val sql = "INSERT INTO location_info (location, counts, accesse_data) VALUES (?, ?, ?)"
    try {
      conn = DriverManager.getConnection("jdbc:mysql://localhost:3306/home/hadoop", "hadoop", "000000fzx")
      // Prepare the statement once and reuse it for every record in the partition
      ps = conn.prepareStatement(sql)
      iterator.foreach(line => {
        ps.setString(1, line._1)
        ps.setInt(2, line._2)
        ps.setDate(3, new Date(System.currentTimeMillis()))
        ps.executeUpdate()
      })
    } catch {
      case e: Exception => println("MySQL Exception: " + e.getMessage)
    } finally {
      if (ps != null)
        ps.close()
      if (conn != null)
        conn.close()
    }
  }

  // Convert a dotted-decimal IP string to its numeric (Long) form
  def ip2Long(ip: String): Long = {
    val fragments = ip.split("[.]")
    var ipNum = 0L
    for (i <- 0 until fragments.length) {
      ipNum = fragments(i).toLong | ipNum << 8L
    }
    ipNum
  }
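
  // Worked example (illustrative, not part of the original post): each octet is
  // folded into the accumulator with a left shift, so for "192.168.1.1":
  //   192                  -> 192
  //   (192 << 8) | 168     -> 49320
  //   (49320 << 8) | 1     -> 12625921
  //   (12625921 << 8) | 1  -> 3232235777
  //   ip2Long("192.168.1.1") == 3232235777L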

  // Binary-search the (startNum, endNum, province) rules for the range containing ip;
  // assumes the rules are sorted by start IP. Returns -1 when no rule matches.
  def binarySearch(lines: Array[(String, String, String)], ip: Long): Int = {

    var low = 0
    var high = lines.length - 1
    while (low <= high) {
      val middle = (low + high) / 2
      if ((ip >= lines(middle)._1.toLong) && (ip <= lines(middle)._2.toLong))
        return middle
      if (ip < lines(middle)._1.toLong)
        high = middle - 1
      else {
        low = middle + 1
      }
    }
    -1
  }

  def main(args: Array[String]): Unit = {

    val conf = new SparkConf().setAppName("IPLocation").setMaster()
    val sc = new SparkContext(conf)
    // Load the IP rules file (path left blank in the original post)
    val ipRulesRdd = sc.textFile("   ").map(line => {

      val fields = line.split("\\|")
      val start_num = fields(2)
      val end_num = fields(3)
      val province = fields(6)
      (start_num, end_num, province)
    })
    // Collect the IP rules to the driver
    val ipRulesArray = ipRulesRdd.collect()
    // Broadcast the rules to every executor
    val ipRulesBroadcast = sc.broadcast(ipRulesArray)
    // Load the access log (path left blank in the original post) and keep the IP field
    val ipsRDD = sc.textFile("   ").map(line => {
      val fields = line.split("\\|")
      fields(1)
    })
    // Look up each IP in the broadcast rules; binarySearch returns -1 when no rule
    // matches, so unmatched IPs are filtered out before counting by province
    val result = ipsRDD.map(ip => {
      val ipNum = ip2Long(ip)
      binarySearch(ipRulesBroadcast.value, ipNum)
    }).filter(_ != -1)
      .map(index => (ipRulesBroadcast.value(index)._3, 1))
      .reduceByKey(_ + _)

    result.foreachPartition(data2MySQL)
    println(result.collect().toBuffer)
    sc.stop()
  }
}
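
The writer above assumes a location_info table already exists and that MySQL is reachable with the credentials shown. A minimal sketch of the assumed schema and a local smoke test of data2MySQL (it only takes an Iterator, so no Spark is needed to exercise it) is below; the DDL and the sample rows are assumptions for illustration, not part of the original code:

// Assumed DDL for the target table (matching the INSERT above):
//   CREATE TABLE location_info (
//     location     VARCHAR(64),
//     counts       INT,
//     accesse_data DATE
//   );

object Data2MySQLSmokeTest {
  def main(args: Array[String]): Unit = {
    // Two hypothetical (province, count) pairs, as produced by reduceByKey
    val sample = Iterator(("Beijing", 123), ("Shanghai", 456))
    IPLocation.data2MySQL(sample) // inserts two rows if the database is reachable
  }
}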

Reading from a MySQL database

import java.sql.DriverManager
import org.apache.spark.rdd.JdbcRDD
import org.apache.spark.{SparkConf, SparkContext}

object JdbcRddDemo {

  def main(args: Array[String]): Unit = {

    val conf = new SparkConf().setAppName("JdbcRddDemo").setMaster("local[2]")
    val sc = new SparkContext(conf)
    // Connection factory: JdbcRDD calls it to open a connection per partition
    val connection = () => {
      Class.forName("com.mysql.jdbc.Driver").newInstance()
      DriverManager.getConnection("jdbc:mysql://localhost:3306/home/hadoop", "hadoop", "000000fzx")
    }
    val jdbcRdd = new JdbcRDD(
      sc,
      connection,
      "SELECT * FROM table_1 where id>= ? AND id <= ?",
      1, 4, 2,
      rs => {
        val id = rs.getInt(1)
        val code = rs.getString(2)
        (id, code)
      }
    )
    val jrdd = jdbcRdd.collect()
    println(jrdd.toBuffer)
    sc.stop()
  }
}
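
The three numeric arguments (1, 4, 2) are the lower bound, upper bound and number of partitions; JdbcRDD substitutes the computed sub-ranges into the two ? placeholders so that each partition reads a disjoint slice of ids. A rough sketch of what that amounts to for this example (illustrative, the exact splitting is internal to JdbcRDD):

// With lowerBound = 1, upperBound = 4, numPartitions = 2, roughly one query per partition:
//   partition 0: SELECT * FROM table_1 WHERE id >= 1 AND id <= 2
//   partition 1: SELECT * FROM table_1 WHERE id >= 3 AND id <= 4
// The query string must therefore contain exactly two '?' placeholders.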