Spark Usage: Fixing Append-Mode Writes to a MySQL Table That Wipe Existing Data
阿新 · Published 2018-12-17
Environment: Spark 2.2.0 (CDH), Scala 2.11.8.
The problem: when appending to a MySQL table with SaveMode.Append, Spark first cleared the existing data and only then wrote the new rows. This left us speechless: we asked for Append and got an overwrite.
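For reference, the write path that exhibited this behaviour is the standard DataFrameWriter JDBC call. A minimal sketch of the reproduction, where the host, database, and credentials are placeholders and df is an existing DataFrame:

```scala
import java.util.Properties
import org.apache.spark.sql.SaveMode

val props = new Properties()
props.setProperty("user", "root")         // placeholder credentials
props.setProperty("password", "secret")
props.setProperty("driver", "com.mysql.jdbc.Driver")

// Despite SaveMode.Append, the rows already in the table were cleared first.
df.write.mode(SaveMode.Append).jdbc("jdbc:mysql://host:3306/databaseName", "page_visit_day", props)
```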
The fix: rewrite two classes from the Spark source. Just drop the two classes below into your own project; there is no need to patch Spark's underlying source code.
1. JdbcUtils
The original code reads:

```scala
if (mode == SaveMode.Append && tableExists) {
  truncateTable(conn, table)
  tableExists = true
}
```

Simply delete the truncateTable(conn, table) call.
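After the edit, the Append branch leaves the existing rows alone and only records that the table is already there:

```scala
if (mode == SaveMode.Append && tableExists) {
  // truncateTable(conn, table)  // removed: Append must not clear the table
  tableExists = true
}
```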
The complete class:
```scala
import java.sql.{Connection, DriverManager, PreparedStatement}
import java.util.Properties

import scala.util.Try

import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.datasources.jdbc.DriverWrapper
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SaveMode}

/**
 * Util functions for JDBC tables.
 */
object JdbcUtils extends Logging {

  def jdbc(mode: SaveMode, url: String, df: DataFrame, table: String,
      connectionProperties: Properties): Unit = {
    val props = new Properties()
    props.putAll(connectionProperties)
    val conn = JdbcUtils.createConnection(url, props)
    try {
      var tableExists = JdbcUtils.tableExists(conn, table)
      if (mode == SaveMode.Ignore && tableExists) {
        return
      }
      if (mode == SaveMode.ErrorIfExists && tableExists) {
        sys.error(s"Table $table already exists.")
      }
      if (mode == SaveMode.Overwrite && tableExists) {
        truncateTable(conn, table)
        tableExists = true
      }
      if (mode == SaveMode.Append && tableExists) {
        // ********** Comment out or delete the line below **********
        // truncateTable(conn, table)
        tableExists = true
      }
      // Create the table if it didn't exist.
      if (!tableExists) {
        val schema = JdbcUtils.schemaString(df, url)
        val sql = s"CREATE TABLE $table ($schema)"
        conn.prepareStatement(sql).executeUpdate()
      }
      JdbcUtils.saveTable(df, url, table, props)
    } finally {
      conn.close()
    }
  }

  /**
   * Establishes a JDBC connection.
   */
  def createConnection(url: String, connectionProperties: Properties): Connection = {
    JDBCRDD.getConnector(connectionProperties.getProperty("driver"), url, connectionProperties)()
  }

  /**
   * Returns true if the table already exists in the JDBC database.
   */
  def tableExists(conn: Connection, table: String): Boolean = {
    // Somewhat hacky, but there isn't a good way to identify whether a table exists for all
    // SQL database systems, considering "table" could also include the database name.
    Try(conn.prepareStatement(s"SELECT 1 FROM $table LIMIT 1").executeQuery().next()).isSuccess
  }

  /**
   * Drops a table from the JDBC database.
   */
  def dropTable(conn: Connection, table: String): Unit = {
    conn.prepareStatement(s"DROP TABLE $table").executeUpdate()
  }

  /**
   * Truncates a table from the JDBC database.
   */
  def truncateTable(conn: Connection, table: String): Unit = {
    conn.prepareStatement(s"TRUNCATE TABLE $table").executeUpdate()
  }

  /**
   * Returns a PreparedStatement that inserts a row into table via conn.
   */
  def insertStatement(conn: Connection, table: String, rddSchema: StructType): PreparedStatement = {
    val fields = rddSchema.fields
    val fieldsSql = new StringBuilder("(")
    var i = 0
    for (f <- fields) {
      fieldsSql.append(f.name)
      if (i == fields.length - 1) {
        fieldsSql.append(")")
      } else {
        fieldsSql.append(",")
      }
      i += 1
    }
    val sql = new StringBuilder(s"INSERT INTO $table ")
    sql.append(fieldsSql.toString())
    sql.append(" VALUES (")
    var fieldsLeft = rddSchema.fields.length
    while (fieldsLeft > 0) {
      sql.append("?")
      if (fieldsLeft > 1) sql.append(", ") else sql.append(")")
      fieldsLeft = fieldsLeft - 1
    }
    conn.prepareStatement(sql.toString())
  }

  /**
   * Saves a partition of a DataFrame to the JDBC database. This is done in
   * a single database transaction in order to avoid repeatedly inserting
   * data as much as possible.
   *
   * It is still theoretically possible for rows in a DataFrame to be
   * inserted into the database more than once if a stage somehow fails after
   * the commit occurs but before the stage can return successfully.
   *
   * This is not a closure inside saveTable() because apparently cosmetic
   * implementation changes elsewhere might easily render such a closure
   * non-Serializable. Instead, we explicitly close over all variables that
   * are used.
   */
  def savePartition(
      getConnection: () => Connection,
      table: String,
      iterator: Iterator[Row],
      rddSchema: StructType,
      nullTypes: Array[Int]): Iterator[Byte] = {
    val conn = getConnection()
    var committed = false
    try {
      conn.setAutoCommit(false) // Everything in the same db transaction.
      val stmt = insertStatement(conn, table, rddSchema)
      try {
        while (iterator.hasNext) {
          val row = iterator.next()
          val numFields = rddSchema.fields.length
          var i = 0
          while (i < numFields) {
            if (row.isNullAt(i)) {
              stmt.setNull(i + 1, nullTypes(i))
            } else {
              rddSchema.fields(i).dataType match {
                case IntegerType => stmt.setInt(i + 1, row.getInt(i))
                case LongType => stmt.setLong(i + 1, row.getLong(i))
                case DoubleType => stmt.setDouble(i + 1, row.getDouble(i))
                case FloatType => stmt.setFloat(i + 1, row.getFloat(i))
                case ShortType => stmt.setInt(i + 1, row.getShort(i))
                case ByteType => stmt.setInt(i + 1, row.getByte(i))
                case BooleanType => stmt.setBoolean(i + 1, row.getBoolean(i))
                case StringType => stmt.setString(i + 1, row.getString(i))
                case BinaryType => stmt.setBytes(i + 1, row.getAs[Array[Byte]](i))
                case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i))
                case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i))
                case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i))
                case _ => throw new IllegalArgumentException(
                  s"Can't translate non-null value for field $i")
              }
            }
            i = i + 1
          }
          stmt.executeUpdate()
        }
      } finally {
        stmt.close()
      }
      conn.commit()
      committed = true
    } finally {
      if (!committed) {
        // The stage must fail. We got here through an exception path, so
        // let the exception through unless rollback() or close() want to
        // tell the user about another problem.
        conn.rollback()
        conn.close()
      } else {
        // The stage must succeed. We cannot propagate any exception close() might throw.
        try {
          conn.close()
        } catch {
          case e: Exception => logWarning("Transaction succeeded, but closing failed", e)
        }
      }
    }
    Array[Byte]().iterator
  }

  /**
   * Compute the schema string for this RDD.
   */
  def schemaString(df: DataFrame, url: String): String = {
    val sb = new StringBuilder()
    val dialect = JdbcDialects.get(url)
    df.schema.fields.foreach { field =>
      val name = field.name
      val typ: String = dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition).getOrElse(
        field.dataType match {
          case IntegerType => "INTEGER"
          case LongType => "BIGINT"
          case DoubleType => "DOUBLE PRECISION"
          case FloatType => "REAL"
          case ShortType => "INTEGER"
          case ByteType => "BYTE"
          case BooleanType => "BIT(1)"
          case StringType => "TEXT"
          case BinaryType => "BLOB"
          case TimestampType => "TIMESTAMP"
          case DateType => "DATE"
          case t: DecimalType => s"DECIMAL(${t.precision},${t.scale})"
          case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC")
        })
      val nullable = if (field.nullable) "" else "NOT NULL"
      sb.append(s", $name $typ $nullable")
    }
    if (sb.length < 2) "" else sb.substring(2)
  }

  /**
   * Saves the RDD to the database in a single transaction.
   */
  def saveTable(
      df: DataFrame,
      url: String,
      table: String,
      properties: Properties = new Properties()): Unit = {
    val dialect = JdbcDialects.get(url)
    val nullTypes: Array[Int] = df.schema.fields.map { field =>
      dialect.getJDBCType(field.dataType).map(_.jdbcNullType).getOrElse(
        field.dataType match {
          case IntegerType => java.sql.Types.INTEGER
          case LongType => java.sql.Types.BIGINT
          case DoubleType => java.sql.Types.DOUBLE
          case FloatType => java.sql.Types.REAL
          case ShortType => java.sql.Types.INTEGER
          case ByteType => java.sql.Types.INTEGER
          case BooleanType => java.sql.Types.BIT
          case StringType => java.sql.Types.CLOB
          case BinaryType => java.sql.Types.BLOB
          case TimestampType => java.sql.Types.TIMESTAMP
          case DateType => java.sql.Types.DATE
          case t: DecimalType => java.sql.Types.DECIMAL
          case _ => throw new IllegalArgumentException(
            s"Can't translate null value for field $field")
        })
    }

    val rddSchema = df.schema
    val driver: String = getDriverClassName(url)
    val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties)
    df.foreachPartition { iterator =>
      savePartition(getConnection, table, iterator, rddSchema, nullTypes)
    }
  }

  def getDriverClassName(url: String): String = DriverManager.getDriver(url) match {
    case wrapper: DriverWrapper => wrapper.wrapped.getClass.getCanonicalName
    case driver => driver.getClass.getCanonicalName
  }
}
```
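One side effect worth knowing: if the target table does not exist yet, jdbc() creates it from the DataFrame schema via schemaString. A rough, hypothetical illustration for a DataFrame with columns (day: String, visits: Long):

```scala
// Hypothetical: with the default dialect fallback this prints something like
//   day TEXT , visits BIGINT
// so jdbc() would run: CREATE TABLE page_visit_day (day TEXT , visits BIGINT)
println(JdbcUtils.schemaString(df, "jdbc:mysql://host:3306/databaseName"))
```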
2. JDBCRDD. This class also has to be rewritten because the JdbcUtils above uses it; leaving it out causes an error. Strip out the rest and keep only what is needed; the result is:
```scala
import java.sql.{Connection, DriverManager}
import java.util.Properties

import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.datasources.jdbc.DriverRegistry

private object JDBCRDD extends Logging {

  /**
   * Returns a function that opens a JDBC connection, registering the
   * driver class first so it is available where the function is invoked.
   */
  def getConnector(driver: String, url: String, properties: Properties): () => Connection = {
    () => {
      try {
        if (driver != null) DriverRegistry.register(driver)
      } catch {
        case e: ClassNotFoundException =>
          logWarning(s"Couldn't find class $driver", e)
      }
      DriverManager.getConnection(url, properties)
    }
  }
}
```
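As a quick sanity check, the factory can be exercised on the driver node before handing it to executors. A minimal sketch; the URL and credentials here are placeholders:

```scala
// Hypothetical placeholders: substitute your real URL and properties.
val url = "jdbc:mysql://host:3306/databaseName"
val prop = new java.util.Properties()
prop.setProperty("user", "root")
prop.setProperty("password", "secret")

val connect = JDBCRDD.getConnector("com.mysql.jdbc.Driver", url, prop)
val conn = connect()
try {
  println(conn.getMetaData.getDatabaseProductName) // prints "MySQL" on success
} finally {
  conn.close()
}
```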
Finally, the call, where df is the DataFrame to write and prop holds the JDBC connection properties:

```scala
val url = "jdbc:mysql://*****:3306/databaseName"
JdbcUtils.jdbc(SaveMode.Append, url, df, "page_visit_day", prop)
```
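Putting it all together, a minimal end-to-end sketch; the sample DataFrame and credentials are placeholders, and the host stays masked as in the original:

```scala
import java.util.Properties
import org.apache.spark.sql.{SaveMode, SparkSession}

val spark = SparkSession.builder().appName("AppendToMysql").getOrCreate()
import spark.implicits._

// Hypothetical sample data; substitute your real DataFrame.
val df = Seq(("2018-12-17", 1024L)).toDF("day", "visits")

val prop = new Properties()
prop.setProperty("user", "root")        // placeholder credentials
prop.setProperty("password", "secret")
prop.setProperty("driver", "com.mysql.jdbc.Driver")

val url = "jdbc:mysql://*****:3306/databaseName"
JdbcUtils.jdbc(SaveMode.Append, url, df, "page_visit_day", prop)

// Re-running the job now appends new rows instead of wiping the table.
```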