Search code examples
Tags: java, json, text-parsing

Fast parsing of JSON with large skippable regions


I have a JSON document with the following structure (a real-world example is here: https://gist.github.com/PavelPenkov/3432fe522e02aa3a8a597020d4ee7361):

{
  "metadata": { /* Huge TYPED object */ },
  "payload": { /* Small flat UNTYPED object */
    "field_1": 1,
    "field_2": "Alice"
  }
}

I want to extract the payload part as fast as possible. The file is huge, and parsing it into a case class is rather slow (5000 op/s on my laptop). So far I've tried:

  1. Parse the whole document into a case class with Jackson.

  2. Parse into an AST with Jackson and extract only the payload field; this is slightly faster.

  3. jsoniter-scala: while it can probably parse the typed part faster, it is unable to parse untyped fields by design.

Are there any other options accessible from Java or (preferably) Scala?


Solution

  • Skipping unwanted JSON values is where jsoniter-scala shines. Yes, it doesn't provide an AST model for JSON, but you can build one yourself or use one provided by a third-party library. Here is an example of a custom codec for the circe AST:

    package io.circe
    
    import java.util
    
    import com.github.plokhotnyuk.jsoniter_scala.core._
    import io.circe.Json._
    
    /** jsoniter-scala codec for the circe `Json` AST.
      *
      * The enclosing file declares `package io.circe`, presumably so that circe internals such
      * as the `JString`/`JNumber`/`JsonLong` node classes and `JsonObject.fromLinkedHashMap`
      * are reachable — confirm against the circe version in use.
      */
    object CirceJsoniter {
      implicit val codec: JsonValueCodec[Json] = new JsonValueCodec[Json] {
        /** Decodes the next JSON value of any kind into a circe `Json` node.
          *
          * @param in      reader positioned before the value
          * @param default returned when the top-level value is a JSON `null`; nested values
          *                (array elements, object fields) decode `null` to `nullValue` so a
          *                caller-supplied default never leaks into them
          * @return the decoded node; malformed input is reported through `in`'s error methods
          */
        override def decodeValue(in: JsonReader, default: Json): Json = {
          var b = in.nextToken()
          if (b == 'n') in.readNullOrError(default, "expected `null` value")
          else if (b == '"') {
            in.rollbackToken()
            new JString(in.readString(null))
          } else if (b == 'f' || b == 't') {
            in.rollbackToken()
            if (in.readBoolean()) Json.True
            else Json.False
          } else if ((b >= '0' && b <= '9') || b == '-') {
            new JNumber({
              in.rollbackToken()
              // Scan ahead to decide between an integral and a floating-point reading, then
              // rewind so the real reader sees the whole number.
              in.setMark() // TODO: add in.readNumberAsString() to Core API of jsoniter-scala
              try {
                b = in.nextByte()
                // Step over the sign first: otherwise the scan stops at '-' and negative numbers
                // with a fraction/exponent (e.g. -1.5) are wrongly routed to readLong.
                if (b == '-') b = in.nextByte()
                while (b >= '0' && b <= '9') b = in.nextByte()
              } catch { case _: JsonReaderException => /* ignore end of input error */ }
              finally in.rollbackToMark()
              if (b == '.' || b == 'e' || b == 'E') new JsonDouble(in.readDouble())
              else new JsonLong(in.readLong())
            })
          } else if (b == '[') {
            new JArray(if (in.isNextToken(']')) Vector.empty
            else {
              in.rollbackToken()
              val elems = Vector.newBuilder[Json]
              do elems += decodeValue(in, nullValue)
              while (in.isNextToken(','))
              if (!in.isCurrentToken(']')) in.arrayEndOrCommaError()
              elems.result()
            })
          } else if (b == '{') {
            new JObject(if (in.isNextToken('}')) JsonObject.empty
            else {
              // LinkedHashMap keeps key insertion order; a repeated key overwrites the earlier value.
              val fields = new util.LinkedHashMap[String, Json]
              in.rollbackToken()
              do fields.put(in.readKeyAsString(), decodeValue(in, nullValue))
              while (in.isNextToken(','))
              if (!in.isCurrentToken('}')) in.objectEndOrCommaError()
              JsonObject.fromLinkedHashMap(fields)
            })
          } else in.decodeError("expected JSON value")
        }

        /** Writes a circe `Json` node.
          *
          * NOTE(review): numbers other than `JsonLong` are written via `toDouble`, which can lose
          * precision for big or high-precision decimals — confirm this is acceptable.
          */
        override def encodeValue(x: Json, out: JsonWriter): Unit = x match {
          case JNull => out.writeNull()
          case JString(s) => out.writeVal(s)
          case JBoolean(b) => out.writeVal(b)
          case JNumber(n) => n match {
            case JsonLong(l) => out.writeVal(l)
            case _ => out.writeVal(n.toDouble)
          }
          case JArray(a) =>
            out.writeArrayStart()
            a.foreach(v => encodeValue(v, out))
            out.writeArrayEnd()
          case JObject(o) =>
            out.writeObjectStart()
            o.toIterable.foreach { case (k, v) =>
              out.writeKey(k)
              encodeValue(v, out)
            }
            out.writeObjectEnd()
        }

        /** JSON `null` maps to circe's `Json.Null`. */
        override def nullValue: Json = Json.Null
      }
    }
    

    Alternatively, if you only need to extract the raw bytes of the payload value, you can use code like this, which does it at a rate of ~300,000 messages per second for the provided sample:

    import com.github.plokhotnyuk.jsoniter_scala.core._
    import com.github.plokhotnyuk.jsoniter_scala.macros._
    import java.nio.charset.StandardCharsets.UTF_8
    import java.util.concurrent.TimeUnit
    import org.openjdk.jmh.annotations._
    import scala.reflect.io.Streamable
    import scala.util.hashing.MurmurHash3
    
    /** Immutable-by-convention wrapper around raw bytes, compared and hashed by content.
      *
      * The companion codec fills it with the bytes returned by `readRawValAsBytes`. Because
      * arrays use reference equality, `equals` and `hashCode` are overridden to look at the
      * byte contents, and `toString` decodes those bytes as UTF-8.
      */
    case class Payload private(bs: Array[Byte]) {
      /** Builds a payload holding the UTF-8 encoding of `s`. */
      def this(s: String) = this(s.getBytes(UTF_8))

      // Content-based hash, computed on first use and then cached.
      override lazy val hashCode: Int = MurmurHash3.arrayHash(bs)

      override def equals(obj: Any): Boolean = obj match {
        case Payload(otherBytes) => java.util.Arrays.equals(bs, otherBytes)
        case _                   => false
      }

      override def toString: String = new String(bs, UTF_8)
    }
    
    object Payload {
      /** Creates a payload from `s`, encoded as UTF-8.
        *
        * Delegates to the auxiliary `String` constructor so the encoding matches it and
        * `toString` (both UTF-8). The platform default charset (`s.getBytes`) is not used.
        */
      def apply(s: String): Payload = new Payload(s)

      /** Codec that keeps the payload as raw JSON bytes instead of building objects from it. */
      implicit val codec: JsonValueCodec[Payload] = new JsonValueCodec[Payload] {
        // Captures the next JSON value as its raw bytes via readRawValAsBytes.
        override def decodeValue(in: JsonReader, default: Payload): Payload = new Payload(in.readRawValAsBytes())

        // Emits the stored bytes back unchanged.
        override def encodeValue(x: Payload, out: JsonWriter): Unit = out.writeRawVal(x.bs)

        // NOTE(review): an empty byte array is not a valid JSON value; if this is ever encoded,
        // writeRawVal would presumably emit nothing for the value — confirm intended.
        override val nullValue: Payload = new Payload(new Array[Byte](0))
      }
    }
    
    /** Message envelope that declares only the `payload` field; other top-level fields such as
      * `metadata` are presumably skipped by the derived codec (confirm the `CodecMakerConfig`).
      */
    case class MessageWithPayload(payload: Payload)

    object MessageWithPayload {
      implicit val codec: JsonValueCodec[MessageWithPayload] = JsonCodecMaker.make(CodecMakerConfig())

      /** Sample message, loaded once from the classpath resource `debezium.json`. */
      val jsonBytes: Array[Byte] = readResource("debezium.json")

      // Reads a classpath resource (resolved relative to this object's class) fully and always
      // closes the stream. A missing resource fails with an explicit message rather than an
      // obscure stream error.
      private def readResource(name: String): Array[Byte] = {
        val in = getClass.getResourceAsStream(name)
        if (in eq null) throw new IllegalStateException(s"Classpath resource '$name' not found")
        try {
          val out = new java.io.ByteArrayOutputStream
          val buf = new Array[Byte](8192)
          Iterator.continually(in.read(buf)).takeWhile(_ >= 0).foreach(n => out.write(buf, 0, n))
          out.toByteArray
        } finally in.close()
      }
    }
    
    /** JMH throughput benchmark that decodes the sample message, keeping only its payload.
      *
      * Configuration: one forked JVM with a fixed 2 GB heap and parallel GC, five 1-second
      * warmup iterations, five 1-second measurement iterations, results in operations/second.
      * NOTE(review): -XX:-UseBiasedLocking targets older JDKs (biased locking was deprecated in
      * JDK 15) — confirm it is accepted by the JVM used to run this.
      */
    @State(Scope.Thread)
    @BenchmarkMode(Array(Mode.Throughput))
    @OutputTimeUnit(TimeUnit.SECONDS)
    @Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
    @Measurement(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
    @Fork(value = 1, jvmArgs = Array(
      "-server",
      "-Xms2g",
      "-Xmx2g",
      "-XX:NewSize=1g",
      "-XX:MaxNewSize=1g",
      "-XX:InitialCodeCacheSize=512m",
      "-XX:ReservedCodeCacheSize=512m",
      "-XX:+UseParallelGC",
      "-XX:-UseBiasedLocking",
      "-XX:+AlwaysPreTouch"
    ))
    class ExtractPayloadReading {
      /** Parses the shared sample bytes; the decoded message is returned so JMH consumes it. */
      @Benchmark
      def jsoniterScala(): MessageWithPayload = {
        val input = MessageWithPayload.jsonBytes
        readFromArray[MessageWithPayload](input)
      }
    }