Implementing MySQL Update Operations with Spark

Background

Out of the box, Spark only offers table-level save modes for MySQL: Append, Overwrite, ErrorIfExists, and Ignore. Sometimes we need row-level operations, such as update. That is, we need to generate a statement like: insert into tb (id,name,age) values (?,?,?) on duplicate key update id=?,name=?,age=?;

Requirement: we want to do this without touching existing calling code and without introducing a new API; a single new option, such as savemode=update, should be enough.

Implementation

Meeting this requirement means patching Spark's source. First, define our own save mode, which is simply Spark's SaveMode plus a new Update entry:

// Spark's built-in save modes, plus the new Update mode
public enum I4SaveMode {
    Append,
    Overwrite,
    ErrorIfExists,
    Ignore,
    Update
}

The JDBC data source is implemented mainly in JdbcRelationProvider. The method we care about is createRelation: there we can map the incoming SaveMode to our own mode and pass it down to saveTable. The modified method looks like this (every change is commented):

override def createRelation(
    sqlContext: SQLContext,
    mode: SaveMode,
    parameters: Map[String, String],
    df: DataFrame): BaseRelation = {
  val options = new JDBCOptions(parameters)
  val isCaseSensitive = sqlContext.conf.caseSensitiveAnalysis
  // Map Spark's SaveMode onto our own save mode
  var saveMode = mode match {
    case SaveMode.Overwrite => I4SaveMode.Overwrite
    case SaveMode.Append => I4SaveMode.Append
    case SaveMode.ErrorIfExists => I4SaveMode.ErrorIfExists
    case SaveMode.Ignore => I4SaveMode.Ignore
  }
  // The key change: if a savemode=update option is present, switch to Update mode
  val parameterLower = parameters.map(kv => (kv._1.toLowerCase, kv._2))
  if (parameterLower.keySet.contains("savemode")) {
    saveMode = if (parameterLower("savemode").equals("update")) I4SaveMode.Update else saveMode
  }
  val conn = JdbcUtils.createConnectionFactory(options)()
  try {
    val tableExists = JdbcUtils.tableExists(conn, options)
    if (tableExists) {
      saveMode match {
        case I4SaveMode.Overwrite =>
          if (options.isTruncate && isCascadingTruncateTable(options.url) == Some(false)) {
            // In this case, we should truncate table and then load.
            truncateTable(conn, options.table)
            val tableSchema = JdbcUtils.getSchemaOption(conn, options)
            saveTable(df, tableSchema, isCaseSensitive, options, saveMode)
          } else {
            ......
          }
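Because the option keys are lower-cased before the lookup, savemode is matched case-insensitively; a tiny sketch of that normalization:

val parameters = Map("SaveMode" -> "update", "url" -> "jdbc:mysql://...")
val parameterLower = parameters.map(kv => (kv._1.toLowerCase, kv._2))
parameterLower.get("savemode")  // Some("update")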

Next comes the saveTable method:

def saveTable(
      df: DataFrame,
      tableSchema: Option[StructType],
      isCaseSensitive: Boolean,
      options: JDBCOptions,
      mode: I4SaveMode): Unit = {
    ......
    // Pass the mode on to statement construction and to each partition write
    val insertStmt = getInsertStatement(table, rddSchema, tableSchema, isCaseSensitive, dialect, mode)
    ......
    repartitionedDF.foreachPartition(iterator => savePartition(
      getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel, mode)
    )
  }

getInsertStatement builds the SQL statement, and savePartition then writes each partition with it. First, how the statement construction changed (every change is commented):

def getInsertStatement(
      table: String,
      rddSchema: StructType,
      tableSchema: Option[StructType],
      isCaseSensitive: Boolean,
      dialect: JdbcDialect,
      mode: I4SaveMode): String = {
    val columns = if (tableSchema.isEmpty) {
      rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",")
    } else {
      val columnNameEquality = if (isCaseSensitive) {
        org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
      } else {
        org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
      } 
      val tableColumnNames = tableSchema.get.fieldNames
      rddSchema.fields.map { col =>
        val normalizedName = tableColumnNames.find(f => columnNameEquality(f, col.name)).getOrElse {
          throw new AnalysisException(s"""Column "${col.name}" not found in schema $tableSchema""")
        }
        dialect.quoteIdentifier(normalizedName)
      }.mkString(",")
    } 
    val placeholders = rddSchema.fields.map(_ => "?").mkString(",")
    // s"INSERT INTO $table ($columns) VALUES ($placeholders)"
   //若為update模式需要單獨構造
    mode match {
            case I4SaveMode.Update ⇒
                val duplicateSetting = rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).map(name ⇒ s"$name=?").mkString(",")
                s"INSERT INTO $table ($columns) VALUES ($placeholders) ON DUPLICATE KEY UPDATE $duplicateSetting"
            case _ ⇒ s"INSERT INTO $table ($columns) VALUES ($placeholders)"
        }
  }
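To make the result concrete, here is a minimal standalone sketch of what the two branches produce for a hypothetical table tb(id, name, age), assuming a MySQL dialect whose quoteIdentifier wraps names in backticks:

val fields = Seq("id", "name", "age")
val columns = fields.map(f => s"`$f`").mkString(",")            // `id`,`name`,`age`
val placeholders = fields.map(_ => "?").mkString(",")           // ?,?,?
val duplicateSetting = fields.map(f => s"`$f`=?").mkString(",") // `id`=?,`name`=?,`age`=?
// Update mode appends the ON DUPLICATE KEY UPDATE clause:
println(s"INSERT INTO tb ($columns) VALUES ($placeholders) ON DUPLICATE KEY UPDATE $duplicateSetting")
// INSERT INTO tb (`id`,`name`,`age`) VALUES (?,?,?) ON DUPLICATE KEY UPDATE `id`=?,`name`=?,`age`=?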

So building the right SQL is just a matter of checking for update mode. Next, the savePartition method, which does the actual writing:

 def savePartition(
      getConnection: () => Connection,
      table: String,
      iterator: Iterator[Row],
      rddSchema: StructType,
      insertStmt: String,
      batchSize: Int,
      dialect: JdbcDialect,
      isolationLevel: Int): Iterator[Byte] = {
    val conn = getConnection()
    var committed = false

    var finalIsolationLevel = Connection.TRANSACTION_NONE
    if (isolationLevel != Connection.TRANSACTION_NONE) {
      try {
        val metadata = conn.getMetaData
        if (metadata.supportsTransactions()) {
          // Update to at least use the default isolation, if any transaction level
          // has been chosen and transactions are supported
          val defaultIsolation = metadata.getDefaultTransactionIsolation
          finalIsolationLevel = defaultIsolation
          if (metadata.supportsTransactionIsolationLevel(isolationLevel))  {
            // Finally update to actually requested level if possible
            finalIsolationLevel = isolationLevel
          } else {
            logWarning(s"Requested isolation level $isolationLevel is not supported; " +
                s"falling back to default isolation level $defaultIsolation")
          }
        } else {
          logWarning(s"Requested isolation level $isolationLevel, but transactions are unsupported")
        }
      } catch {
        case NonFatal(e) => logWarning("Exception while detecting transaction support", e)
      }
    }
    val supportsTransactions = finalIsolationLevel != Connection.TRANSACTION_NONE

    try {
      if (supportsTransactions) {
        conn.setAutoCommit(false) // Everything in the same db transaction.
        conn.setTransactionIsolation(finalIsolationLevel)
      }
      val stmt = conn.prepareStatement(insertStmt)
      val setters = rddSchema.fields.map(f => makeSetter(conn, dialect, f.dataType))
      val nullTypes = rddSchema.fields.map(f => getJdbcType(f.dataType, dialect).jdbcNullType)
      val numFields = rddSchema.fields.length

      try {
        var rowCount = 0
        while (iterator.hasNext) {
          val row = iterator.next()
          var i = 0
          while (i < numFields) {
            if (row.isNullAt(i)) {
              stmt.setNull(i + 1, nullTypes(i))
            } else {
              setters(i).apply(stmt, row, i)
            }
            i = i + 1
          }
          stmt.addBatch()
          rowCount += 1
          if (rowCount % batchSize == 0) {
            stmt.executeBatch()
            rowCount = 0
          }
        }
        if (rowCount > 0) {
          stmt.executeBatch()
        }
      } finally {
        stmt.close()
      }
      if (supportsTransactions) {
        conn.commit()
      }
      committed = true
      Iterator.empty
    } catch {
      case e: SQLException =>
        val cause = e.getNextException
        if (cause != null && e.getCause != cause) {
          // If there is no cause already, set 'next exception' as cause. If cause is null,
          // it *may* be because no cause was set yet
          if (e.getCause == null) {
            try {
              e.initCause(cause)
            } catch {
              // Or it may be null because the cause *was* explicitly initialized, to *null*,
              // in which case this fails. There is no other way to detect it.
              // addSuppressed in this case as well.
              case _: IllegalStateException => e.addSuppressed(cause)
            }
          } else {
            e.addSuppressed(cause)
          }
        }
        throw e
    } finally {
      if (!committed) {
        // The stage must fail.  We got here through an exception path, so
        // let the exception through unless rollback() or close() want to
        // tell the user about another problem.
        if (supportsTransactions) {
          conn.rollback()
        }
        conn.close()
      } else {
        // The stage must succeed.  We cannot propagate any exception close() might throw.
        try {
          conn.close()
        } catch {
          case e: Exception => logWarning("Transaction succeeded, but closing failed", e)
        }
      }
    }
  }

The general idea: build the setter templates (setters) from the schema once, before iterating over a partition's rows; during iteration, just apply the templates to each row instead of re-checking every value's data type per row.
Without update mode: insert into tb (id,name,age) values (?,?,?)
With update mode: insert into tb (id,name,age) values (?,?,?) on duplicate key update id=?,name=?,age=?;
That is, update mode doubles the placeholders, so each row's values have to be fed to the PreparedStatement a second time. The original makeSetter method:

private def makeSetter(
      conn: Connection,
      dialect: JdbcDialect,
      dataType: DataType): JDBCValueSetter = dataType match {
    case IntegerType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setInt(pos + 1, row.getInt(pos))
    case LongType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setLong(pos + 1, row.getLong(pos))
    ...
  }

We only need one extra parameter, a relative position offset, to control this:

private def makeSetter(
      conn: Connection,
      dialect: JdbcDialect,
      dataType: DataType): JDBCValueSetter = dataType match {
    case IntegerType =>
      (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) =>
        stmt.setInt(pos + 1, row.getInt(pos - offset))
    case LongType =>
      (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) =>
        stmt.setLong(pos + 1, row.getLong(pos - offset))
    ...
  }
In non-update mode the offset is always 0. In update mode the offset is 0 while the position is still below the field count (midField), and equals the field count once it passes it, so the second half of the placeholders reads the same row values again. The modified savePartition method:

def savePartition(
                 getConnection: () => Connection,
                 table: String,
                 iterator: Iterator[Row],
                 rddSchema: StructType,
                 insertStmt: String,
                 batchSize: Int,
                 dialect: JdbcDialect,
                 isolationLevel: Int,
                 mode: I4SaveMode): Iterator[Byte] = {
    ...
    // Are we in update mode?
    val isUpdateMode = mode == I4SaveMode.Update
    val stmt = conn.prepareStatement(insertStmt)
    val setters: Array[JDBCValueSetter] = rddSchema.fields.map(f => makeSetter(conn, dialect, f.dataType))
    val nullTypes = rddSchema.fields.map(f => getJdbcType(f.dataType, dialect).jdbcNullType)
    val length = rddSchema.fields.length
    // In update mode there are twice as many placeholders
    val numFields = if (isUpdateMode) length * 2 else length
    val midField = numFields / 2
    try {
        var rowCount = 0
        while (iterator.hasNext) {
            val row = iterator.next()
            var i = 0
            while (i < numFields) {
                if (isUpdateMode) {
                    if (i < midField) {
                        // Still within the field count: offset is 0
                        if (row.isNullAt(i)) {
                            stmt.setNull(i + 1, nullTypes(i))
                        } else {
                            setters(i).apply(stmt, row, i, 0)
                        }
                    } else {
                        // Past the field count: offset is midField, i.e. the field count
                        if (row.isNullAt(i - midField)) {
                            stmt.setNull(i + 1, nullTypes(i - midField))
                        } else {
                            setters(i - midField).apply(stmt, row, i, midField)
                        }
                    }
                } else {
                    if (row.isNullAt(i)) {
                        stmt.setNull(i + 1, nullTypes(i))
                    } else {
                        setters(i).apply(stmt, row, i, 0)
                    }
                }
                i = i + 1
            }
          ...
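The index arithmetic is easy to get wrong, so here is a minimal standalone sketch of the mapping for a hypothetical three-column row in update mode: the statement parameter i + 1 always advances, while the row value is read from i - offset:

val length = 3                  // e.g. columns id, name, age
val numFields = length * 2      // placeholders are doubled in update mode
val midField = numFields / 2
for (i <- 0 until numFields) {
  val offset = if (i < midField) 0 else midField
  println(s"stmt parameter ${i + 1} <- row column ${i - offset}")
}
// stmt parameter 1 <- row column 0
// stmt parameter 2 <- row column 1
// stmt parameter 3 <- row column 2
// stmt parameter 4 <- row column 0   (the same values again, for the UPDATE clause)
// stmt parameter 5 <- row column 1
// stmt parameter 6 <- row column 2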

Once the source is modified, recompile, repackage, and replace the corresponding jar in your deployment. There is a shortcut: create a project with the same package names, compile the modified sources into a jar, and then replace only the affected class files inside the production jar with the ones from yours.

How to use it

When update mode is needed:

df.write.option("saveMode","update").jdbc(...)
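A fuller sketch with hypothetical connection details; note that ON DUPLICATE KEY UPDATE only takes effect when the target table has a PRIMARY KEY or UNIQUE index for MySQL to detect duplicates on:

import java.util.Properties

val props = new Properties()
props.put("user", "root")        // hypothetical credentials
props.put("password", "secret")

df.write
  .option("saveMode", "update")  // picked up case-insensitively by our createRelation
  .jdbc("jdbc:mysql://localhost:3306/test", "tb", props)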
