程式人生 > 使用SparkSQL內建函式介面開發StructType/Row轉Json函式

使用SparkSQL內建函式介面開發StructType/Row轉Json函式

需求

將DataFrame中的StructType型別欄位下的所有內容轉換為Json字串。

spark版本: 1.6.1

思路

  • DataFrame有toJSON方法,可將每個Row都轉為一個Json字串,並返回RDD[String]
  • DataFrame.write.json方法,可將資料寫為Json格式檔案

跟蹤上述兩處程式碼,發現最終都會呼叫Spark原始碼中的org.apache.spark.sql.execution.datasources.json.JacksonGenerator類,使用Jackson,根據傳入的StructType、JsonGenerator和InternalRow,生成Json字串。

開發

我們的函式只需傳入一個引數,也就是需要轉換的欄位,因此需要實作org.apache.spark.sql.catalyst.expressions套件下的UnaryExpression。

後續對功能進行了擴充,非StructType型別的輸入也可以轉換:先將其包裝成只有一個欄位(col)的StructType,轉換後再從輸出中取出該欄位的值。

package org.apache.spark.sql.catalyst.expressions


import java.io.CharArrayWriter


import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenContext
import org.apache.spark.sql.catalyst.expressions.codegen.GeneratedExpressionCode
import org.apache.spark.sql.execution.datasources.json.JacksonGenerator
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.Metadata
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType


import com.fasterxml.jackson.core.JsonFactory
import org.apache.spark.unsafe.types.UTF8String


/**
 * Catalyst expression that converts the value of its single child to a JSON string.
 *
 * A `StructType` child is rendered as a JSON object through Spark's `JacksonGenerator`
 * (the same writer used by `DataFrame.toJSON` / `DataFrame.write.json`). Any other child type
 * is first wrapped in a one-field struct whose field is named `col`; only that field's JSON
 * value is kept, e.g. the string `12` becomes `"12"`.
 *
 * @param child expression whose (non-null) value is converted; a NULL input yields NULL
 * @author yizhu.sun (2016-08-30)
 */
case class Json(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {

  override def dataType: DataType = StringType

  // Accepting exactly the child's own type means no implicit cast is ever inserted.
  override def inputTypes: Seq[DataType] = Seq(child.dataType)

  /**
   * Schema handed to `JacksonGenerator`: the child's own struct type, or a one-field wrapper
   * struct (`col`) for non-struct children. Lazy so that constructing the node does not force
   * `child.dataType`.
   */
  lazy val inputStructType: StructType = child.dataType match {
    case st: StructType => st
    case other => StructType(Seq(StructField("col", other, child.nullable, Metadata.empty)))
  }

  // Every input type is accepted.
  override def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess

  /** True when the child already produces a struct, i.e. no `col` wrapper is needed. */
  private def childIsStruct: Boolean = child.dataType match {
    case _: StructType => true
    case _ => false
  }

  /** Strips the `{"col":` prefix and trailing `}` introduced by wrapping a non-struct value. */
  private def unwrapColumn(json: String): String =
    json.substring(json.indexOf(":") + 1, json.lastIndexOf("}"))

  // Modelled on org.apache.spark.sql.DataFrame.toJSON and
  // org.apache.spark.sql.execution.datasources.json.JsonOutputWriter.writeInternal.
  protected override def nullSafeEval(data: Any): UTF8String = {
    val row = if (childIsStruct) data.asInstanceOf[InternalRow] else InternalRow(data)
    val writer = new CharArrayWriter
    val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null)
    try {
      JacksonGenerator(inputStructType, gen)(row)
      gen.flush()
    } finally {
      // Release the generator even when generation fails.
      gen.close()
    }
    val json = writer.toString
    UTF8String.fromString(if (childIsStruct) json else unwrapColumn(json))
  }

  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    val writer = ctx.freshName("writer")
    val gen = ctx.freshName("gen")
    val st = ctx.freshName("st")
    val json = ctx.freshName("json")
    // The schema cannot be referenced directly from generated Java, so it is embedded as a
    // string literal and deserialized via Json.getStructType. Backslashes must be escaped before
    // quotes, otherwise a `\` in a field name or metadata corrupts the literal.
    val schemaLiteral = inputStructType.json.replace("\\", "\\\\").replace("\"", "\\\"")
    val isStruct = childIsStruct
    def rowExpr(value: String): String =
      if (isStruct) value
      else s"new org.apache.spark.sql.catalyst.expressions.GenericInternalRow(new Object[]{$value})"
    val resultExpr =
      if (isStruct) json
      else s"""$json.substring($json.indexOf(":") + 1, $json.lastIndexOf("}"))"""
    // Failures are rethrown (not turned into NULL) so codegen behaves like nullSafeEval.
    nullSafeCodeGen(ctx, ev, value => {
      s"""
        | com.fasterxml.jackson.core.JsonGenerator $gen = null;
        | try {
        |   org.apache.spark.sql.types.StructType $st = ${classOf[Json].getName}.getStructType("$schemaLiteral");
        |   java.io.CharArrayWriter $writer = new java.io.CharArrayWriter();
        |   $gen = new com.fasterxml.jackson.core.JsonFactory().createGenerator($writer).setRootValueSeparator(null);
        |   org.apache.spark.sql.execution.datasources.json.JacksonGenerator.apply($st, $gen, ${rowExpr(value)});
        |   $gen.flush();
        |   String $json = $writer.toString();
        |   ${ev.value} = UTF8String.fromString($resultExpr);
        | } catch (Exception e) {
        |   throw new RuntimeException("Failed to convert value to JSON", e);
        | } finally {
        |   if ($gen != null) $gen.close();
        | }
       """.stripMargin
    })
  }
}


/**
 * Companion of the `Json` expression; hosts the schema lookup invoked by its generated code.
 */
object Json {

  /**
   * Deserialized schemas keyed by their JSON text. Generated code calls `getStructType` for
   * every non-null row and may do so from several task threads in the same JVM at once. A plain
   * `mutable.Map` is unsafe under concurrent writes, so a lock-free `TrieMap` is used instead.
   */
  val structTypeCache: scala.collection.concurrent.Map[String, StructType] =
    scala.collection.concurrent.TrieMap[String, StructType]()

  /**
   * Returns the struct type described by `json`, deserializing it on first use and caching it.
   *
   * @param json a `StructType` serialized with `DataType.json`
   * @return the deserialized struct type
   * @throws ClassCastException if `json` describes a data type other than a struct
   */
  def getStructType(json: String): StructType =
    structTypeCache.get(json) match {
      case Some(cached) => cached
      case None =>
        println(">>>>> get StructType from json:")
        println(json)
        val parsed = DataType.fromJson(json).asInstanceOf[StructType]
        // If another thread stored a value first, reuse that instance.
        structTypeCache.putIfAbsent(json, parsed).getOrElse(parsed)
    }
}


註冊

注意,SQLContext.functionRegistry的可見性為protected[sql],因此下列註冊程式碼需放在org.apache.spark.sql套件(或其子套件)內才能存取。

// Derive the function name, its ExpressionInfo and the builder from the Json expression class,
// then register all three. functionRegistry is protected[sql], so this must run from code that
// can access it.
val registration = FunctionRegistry.expression[Json]("json")
val (name, (info, builder)) = registration
sqlContext.functionRegistry.registerFunction(name, info, builder)

測試

// Nested schema: struct column `x` holding three nullable fields a, b, c.
val subSchema = new StructType()
  .add("a", StringType, nullable = true)
  .add("b", StringType, nullable = true)
  .add("c", IntegerType, nullable = true)

val schema = new StructType().add("x", subSchema, nullable = true)

// Two rows whose `x` struct mixes present and null fields.
val rows = Seq(
  Row(Row("12", null, 123)),
  Row(Row(null, "2222", null)))
val rdd = sc.makeRDD(rows)

val df = sqlContext.createDataFrame(rdd, schema)

df.registerTempTable("df")

import sqlContext.sql

// For each query: print its rows, then its schema.
for (query <- Seq("select x, x.a from df", "select json(x), json(x.a) from df")) {
  sql(query).show()
  sql(query).printSchema()
}

結果

+----------------+----+
|x               |a   |
+----------------+----+
|[12,null,123]   |12  |
|[null,2222,null]|null|
+----------------+----+

root
 |-- x: struct (nullable = true)
 |    |-- a: string (nullable = true)
 |    |-- b: string (nullable = true)
 |    |-- c: integer (nullable = true)
 |-- a: string (nullable = true)

>>>>> get StructType from json:
{"type":"struct","fields":[{"name":"a","type":"string","nullable":true,"metadata":{}},{"name":"b","type":"string","nullable":true,"metadata":{}},{"name":"c","type":"integer","nullable":true,"metadata":{}}]}
>>>>> get StructType from json:
{"type":"struct","fields":[{"name":"col","type":"string","nullable":true,"metadata":{}}]}

+------------------+----+
|_c0               |_c1 |
+------------------+----+
|{"a":"12","c":123}|"12"|
|{"b":"2222"}      |null|
+------------------+----+

root
 |-- _c0: string (nullable = true)
 |-- _c1: string (nullable = true)


需要注意的點

  1. 使用SparkSQL自定義函式一般有兩種方法,一種是使用開放的api註冊簡單函式,即呼叫sqlContext.udf.register方法。另一種就是使用SparkSQL內建函式的註冊方法(本例就是使用的這種方法)。前者優勢是開發簡單,但是實現不了較為複雜的功能,例如本例中需要獲取傳入的InternalRow的StructType,或者需要實現類似 def fun(arg: Seq[T]): T 這種泛型相關的功能(sqlContext.udf.register的註冊方式無法註冊返回值為Any的函式)。
  2. 本例中實現genCode函式時遇到了困難,即需要在生成的Java程式碼中構建StructType物件。這個最終通過序列化的思路解決,先使用StructType.json方法將StructType物件序列化為String(嵌入Java程式碼前需跳脫成合法的字串常值),然後在Java程式碼中呼叫DataType.fromJson反序列化為StructType物件。由於生成的程式碼每處理一筆非null資料都會呼叫Json.getStructType,該方法以json字串為key快取反序列化結果,避免重複解析。