Tags: python, json, apache-spark, pyspark, accumulator

PySpark - checking JSON format using accumulator


How do I check whether a JSON file is corrupted, e.g. a missing { or }, a missing comma, or a wrong datatype? I am trying to achieve this with an accumulator because the process runs on multiple executors.

from pyspark import SparkConf, AccumulatorParam
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, BooleanType

spark_config = SparkConf().setAppName(application_name)
ss = SparkSession.builder.config(conf=spark_config).getOrCreate()

class StringAccumulatorParam(AccumulatorParam):
    def zero(self, v):
        return []
    def addInPlace(self, variable, value):
        variable.append(value)
        return variable

errorCount = ss.sparkContext.accumulator(0)
errorValues = ss.sparkContext.accumulator("", StringAccumulatorParam())

newSchema = StructType([
    StructField("id", IntegerType(), True),
    StructField("name", StringType(), True),
    StructField("status", BooleanType(), True)])

errorDF = ss.read.json("/Users/test.jsonl")
errorDF2 = ss.createDataFrame(errorDF.rdd, newSchema).cache()

def checkErrorCount(row):
    global errorCount
    # meant to count only the rows that failed to parse, but I am not sure how to detect them here
    errorCount.add(1)
    errorValues.add(row["id"])

errorDF.foreach(lambda x: checkErrorCount(x))
print("{} rows had questionable values.".format(errorCount.value))

ss.stop()

Here is the corrupt JSON file:

{"name":"Standards1","id":90,"status":true}
{"name":"Standards2","id":91
{"name":"Standards3","id":92,"status":true}
{"name":781,"id":93,"status":true}

Solution

  • I had a play with this and came up with the following.

    Of the two solutions, I think the difference-of-counts approach will be faster, since it uses Spark's native JSON processing.

    The UDF solution does the JSON parsing in Python, so you pay the cost of shipping each file line from Java to Python; it will probably be slower. Two further sketches, one that pulls out the corrupt lines themselves and one that feeds them into an accumulator as the question asked, follow the code below.

    import json
    from pyspark import SparkConf
    from pyspark.sql import SparkSession
    from pyspark.sql.functions import sum, udf
    from pyspark.sql.types import LongType
    
    application_name = 'Count bad JSON lines'
    spark_config = SparkConf().setAppName(application_name)
    ss = SparkSession.builder.config(conf=spark_config).getOrCreate()
    
    # Difference of counts solution
    input_path = '/baddata.json'
    total_lines = ss.read.text(input_path).count()
    good_lines = ss.read.option('mode', 'DROPMALFORMED').json(input_path).count()
    bad_lines = total_lines - good_lines
    print('Found {} bad JSON lines in data'.format(bad_lines))
    
    # Parse JSON with UDF solution
    def is_bad(line):
        try:
            json.loads(line)
            return 0
        except ValueError:
            return 1
    
    is_bad_udf = udf(is_bad, LongType())
    lines = ss.read.text(input_path)
    bad_sum = lines.select(sum(is_bad_udf('value'))).collect()[0][0]
    print('Got {} bad lines'.format(bad_sum))
    
    ss.stop()
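
    If you also want to see which lines are bad, not just how many, the native route can be extended. The sketch below is a variant of the count approach and rests on a few assumptions: Spark's PERMISSIVE read mode keeps the raw text of unparseable lines in a corrupt-record column (default name _corrupt_record), that column only appears in the inferred schema when at least one line is actually corrupt, and the frame should be cached before filtering on it (recent Spark versions reject raw reads that reference only the corrupt-record column). It reuses ss and input_path from the snippet above, so run it before ss.stop().

    # Sketch: surface the raw text of unparseable lines via PERMISSIVE mode.
    raw = (ss.read
             .option('mode', 'PERMISSIVE')
             .option('columnNameOfCorruptRecord', '_corrupt_record')
             .json(input_path)
             .cache())  # cache before filtering on the corrupt-record column

    bad_rows = raw.filter(raw['_corrupt_record'].isNotNull())
    for r in bad_rows.select('_corrupt_record').collect():
        print('Corrupt line: {}'.format(r['_corrupt_record']))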
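
    And since the question specifically asked for an accumulator, here is a minimal sketch of how the per-line check could feed a list accumulator instead of a sum, again assuming the same ss and input_path as above. ListAccumulatorParam and collect_bad_lines are illustrative names, not Spark API; the point to note is that addInPlace has to handle both a single line added on an executor and a whole partial list merged back on the driver.

    import json

    from pyspark import AccumulatorParam

    class ListAccumulatorParam(AccumulatorParam):
        def zero(self, value):
            return []

        def addInPlace(self, acc, update):
            # update is either one bad line (added on an executor) or a
            # partial list of bad lines (merged back on the driver)
            if isinstance(update, list):
                acc.extend(update)
            else:
                acc.append(update)
            return acc

    bad_lines_acc = ss.sparkContext.accumulator([], ListAccumulatorParam())

    def collect_bad_lines(row):
        # row.value is one raw line of the input file
        try:
            json.loads(row.value)
        except ValueError:
            bad_lines_acc.add(row.value)

    ss.read.text(input_path).foreach(collect_bad_lines)
    print('Bad lines: {}'.format(bad_lines_acc.value))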