Tags: pyspark, jsonpath

json path not working as expected when used with get_json_object


TL;DR: The following JSON path isn't working for me when used with pyspark.sql.functions.get_json_object.

$.Blocks[?(@.Type=='LINE')].Confidence

Long version...

I want to group values from an array within a single row.

For example, for the structure below

root
|--id: string
|--payload: string

the value of payload is a string containing a JSON document with the structure below:

{
    "Blocks": [
        {
            "Type": "LINE",
            "Confidence": 90
        },
        {
            "Type": "LINE",
            "Confidence": 98
        },
        {
            "Type": "WORD",
            "Confidence": 99
        },
        {
            "Type": "PAGE",
            "Confidence": 97
        },
        {
            "Type": "PAGE",
            "Confidence": 89
        },
        {
            "Type": "WORD",
            "Confidence": 99
        }
    ]
}

I want to aggregate all of the confidences by type so we get the following new column:

{
    "id": 12345,
    "payload": "..."
    "confidence": [
        {
            "Type": "WORD",
            "Confidence": [
                99,
                99
            ]
        },
        {
            "Type": "PAGE",
            "Confidence": [
                97,
                89
            ]
        },
        {
            "Type": "LINE",
            "Confidence": [
                90,
                98
            ]
        }
    ]
}

To do this I plan on using get_json_object(...) to extract confidences for each type of block.

For example...

get_json_object(col("payload"), "$.Blocks[?(@.Type=='LINE')].Confidence")

But $.Blocks[?(@.Type=='LINE')].Confidence keeps returning null. Why is that?

I verified that the JSON path works by testing it on https://jsonpath.curiousconcept.com/# against the sample payload above, and got the following result:

[
   90,
   98
]
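
The equivalent filter expressed in plain Python (over a trimmed copy of the sample payload) gives the same answer, so the path itself seems fine:

import json

payload = '{"Blocks": [{"Type": "LINE", "Confidence": 90}, {"Type": "LINE", "Confidence": 98}, {"Type": "WORD", "Confidence": 99}]}'

# equivalent of $.Blocks[?(@.Type=='LINE')].Confidence
blocks = json.loads(payload)["Blocks"]
print([b["Confidence"] for b in blocks if b["Type"] == "LINE"])  # [90, 98]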

If using the path above isn't an option, how would one go about aggregating this?

Below is the full code sample. I expect the first .show() to print out [90, 98] in the confidence column.

from pyspark.sql import SparkSession
from pyspark.sql.types import StructField, StringType, StructType, IntegerType
from pyspark.sql.functions import get_json_object, col


def main():
    spark = SparkSession.builder.appName('test_session').getOrCreate()
    df = spark.createDataFrame([
        (
            12345,  # id
            """
{
        "Blocks": [
            {
                "Type": "LINE",
                "Confidence": 90
            },
            {
                "Type": "LINE",
                "Confidence": 98
            },
            {
                "Type": "WORD",
                "Confidence": 99
            },
            {
                "Type": "PAGE",
                "Confidence": 97
            },
            {
                "Type": "PAGE",
                "Confidence": 89
            },
            {
                "Type": "WORD",
                "Confidence": 99
            }
        ]
    }

            """  # payload
        )
    ], StructType([
        StructField("id", IntegerType(), True),
        StructField("payload", StringType(), True)
    ]))
    
    # this prints out null (why?)
    df.withColumn("confidence", get_json_object(col("payload"), "$.Blocks[?(@.Type=='LINE')].Confidence")).show()
    
    # this prints out the correct values, [90,98,99,97,89,99]
    df.withColumn("confidence", get_json_object(col("payload"), "$.Blocks[*].Confidence")).show()


if __name__ == "__main__":
    main()


Solution

  • There is no official documentation of how Spark parses JSON paths, but based on its source code it looks like it does not support @ (the current object) at all. In fact, it supports only a very limited syntax:

    // parse `[*]` and `[123]` subscripts
    // parse `.name` or `['name']` child expressions
    // child wildcards: `..`, `.*` or `['*']`
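
    For example, only paths built from those pieces work. Here is a quick sketch of my own (assuming the df built in the question) showing two supported paths:

    from pyspark.sql import functions as F

    # an explicit subscript and a wildcard subscript, both supported
    df.select(
        F.get_json_object('payload', '$.Blocks[0].Confidence'),   # "90"
        F.get_json_object('payload', '$.Blocks[*].Confidence')    # "[90,98,99,97,89,99]"
    ).show()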
    

    So if you're open to another approach, here it is with a pre-defined schema and the functions from_json, explode, and collect_list:

    from pyspark.sql import functions as F, types as T

    # schema for the JSON document stored in the `payload` string column
    schema = T.StructType([
        T.StructField('Blocks', T.ArrayType(T.StructType([
            T.StructField('Type', T.StringType()),
            T.StructField('Confidence', T.IntegerType())
        ])))
    ])
    
    (df
        .withColumn('json', F.from_json('payload', schema))    # parse the payload string
        .withColumn('block', F.explode('json.Blocks'))         # one row per array element
        .select('id', 'block.*')
        .groupBy('id', 'Type')
        .agg(F.collect_list('Confidence').alias('confidence'))
        .show(10, False)
    )
    
    # +-----+----+----------+
    # |id   |Type|confidence|
    # +-----+----+----------+
    # |12345|PAGE|[97, 89]  |
    # |12345|WORD|[99, 99]  |
    # |12345|LINE|[90, 98]  |
    # +-----+----+----------+
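
    If you also want the result back in the exact shape from the question (a single confidence array column per id), a second groupBy can fold the per-type rows up again. Here is a sketch along the same lines; note that the element order inside the array is not guaranteed:

    grouped = (df
        .withColumn('json', F.from_json('payload', schema))
        .withColumn('block', F.explode('json.Blocks'))
        .select('id', 'block.*')
        .groupBy('id', 'Type')
        .agg(F.collect_list('Confidence').alias('Confidence')))

    # fold the per-type rows into one array-of-structs column per id,
    # e.g. [{LINE, [90, 98]}, {WORD, [99, 99]}, {PAGE, [97, 89]}]
    (grouped
        .groupBy('id')
        .agg(F.collect_list(F.struct('Type', 'Confidence')).alias('confidence'))
        .show(10, False))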