Tags: python, recursion, pyspark

Recursively adding columns to pyspark dataframe nested arrays


I'm working with a PySpark DataFrame that contains multiple levels of nested arrays of structs. My goal is to add to each nested array a hash column for that array, plus the record's top-level hash column. Since I don't know how deeply the arrays can be nested, I need recursion to reach all of them.

So for this example schema

from pyspark.sql.types import StructType, StructField, StringType, ArrayType

schema = StructType([
    StructField("name", StringType()),
    StructField("experience", ArrayType(StructType([
        StructField("role", StringType()),
        StructField("duration", StringType()),
        StructField("company", ArrayType(StructType([
            StructField("company_name", StringType()),
            StructField("location", StringType())
        ])))
    ])))
])

The desired output schema would look like this:

hashed_schema = StructType([
    StructField("name", StringType()),
    StructField("experience", ArrayType(StructType([
        StructField("role", StringType()),
        StructField("duration", StringType()),
        StructField("experience_hash", StringType()),  # Added hash for the experience collection
        StructField("company", ArrayType(StructType([
            StructField("company_name", StringType()),
            StructField("location", StringType()),
            StructField("company_hash", StringType())  # Added hash for the company collection
        ])))
    ]))),
    StructField("employee_hash", StringType()),  # Added hash for the entire record
])

I have tried to write code with recursion that would iterate through each nested array and hash its columns. While it seems to work for first-level nested arrays, the recursion part does not work: I get an error that the recursion is too deep.


from pyspark.sql.functions import col, concat_ws, lit, md5, transform

def hash_for_level(level_path):
    # md5 over the path elements joined with "_" (hashes the path literals)
    return md5(concat_ws("_", *[lit(elem) for elem in level_path]))

def add_hash_columns(df, level_path, current_struct, root_hash_col=None):
    # If this is the root level, create the root hash
    if not level_path and root_hash_col is None:
        root_hash_col = 'employee_hash'
        df = df.withColumn(root_hash_col, hash_for_level(['employee']))
    
    # Traverse the current structure and add hash columns
    for field in current_struct.fields:
        new_level_path = level_path + [field.name]
        # If the field is an array of structs, add a hash for each element in the array
        if isinstance(field.dataType, ArrayType):
            nested_struct = field.dataType.elementType
            hash_expr = transform(
                col('.'.join(level_path + [field.name])),
                lambda x: x.withField(new_level_path[-1] + '_hash', hash_for_level(new_level_path))
                    .withField(root_hash_col, col(root_hash_col))  # Include the root hash
            )
            # Add the hash column to the array elements
            df = df.withColumn('.'.join(level_path + [field.name]), hash_expr)
            # Recursion call to apply the same logic for nested arrays
            df = add_hash_columns(df, new_level_path, nested_struct, root_hash_col)
            
    # Add a hash column at the current level
    if level_path:
        #print("Level path:", level_path)
        hash_col_name = '_'.join(level_path) + '_hash'
        df = df.withColumn(hash_col_name, hash_for_level(level_path))
        if root_hash_col:
            # Ensure the root hash is included at each struct level
            df = df.withColumn(root_hash_col, col(root_hash_col))
            
    return df

df = spark.createDataFrame([], schema)
df = add_hash_columns(df, [], df.schema)
df
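
For reference, here is a minimal, self-contained sketch of the transform + withField pattern the code above relies on (requires Spark 3.1+; the toy column names items, v and v_hash are mine, not from the question):

from pyspark.sql import functions as F

toy = spark.createDataFrame(
    [([(1,), (2,)],)],
    "items: array<struct<v: int>>",
)

# Append a computed field to every struct element of the array in one pass
toy = toy.withColumn(
    "items",
    F.transform("items", lambda x: x.withField("v_hash", F.md5(x["v"].cast("string")))),
)
toy.printSchema()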

Solution

  • The approach in the question, using transform on the arrays, worked for me. But instead of adding the hash columns with withColumn, I generate a (large) list of column expressions that includes the hash columns and pass this list to a single select call.

    from pyspark.sql import functions as F
    from pyspark.sql import types as T
    
    # handling the first level
    def add_hashes(df):
        hashcols = []
        for field in df.schema:
            hashcols.append(df[field.name])
            if isinstance(field.dataType, T.ArrayType):
                yield transform_array(field.name, df[field.name], df.schema[field.name].dataType.elementType)
            else:
                yield df[field.name]
        yield F.hash(*hashcols).alias('hash')
    
    # special handling for adding the hash column to arrays
    def transform_array(colname: str, col: F.Column, elementOfStruct: T.StructType):
        fields = elementOfStruct.fields
        hashcols = []
        def process_struct_elements(x: F.Column):
            for field in fields:
                name = field.name
                dataType = field.dataType
                hashcols.append(x[name])
                if isinstance(dataType, T.ArrayType):
                    yield transform_array(name, x[name], dataType.elementType)
                else:
                    yield x[name].alias(name)
            yield F.hash(*hashcols).alias('hash')
        return F.transform(col, lambda x: F.struct(*process_struct_elements(x))).alias(colname)
    
    # starting the process
    cols = list(add_hashes(df))
    df.select(cols).printSchema()
    df.select(cols).show(truncate=True)
    

    Output:

    root
     |-- experience: array (nullable = true)
     |    |-- element: struct (containsNull = false)
     |    |    |-- company: array (nullable = true)
     |    |    |    |-- element: struct (containsNull = false)
     |    |    |    |    |-- company_name: string (nullable = true)
     |    |    |    |    |-- location: string (nullable = true)
     |    |    |    |    |-- hash: integer (nullable = false)
     |    |    |-- duration: string (nullable = true)
     |    |    |-- role: string (nullable = true)
     |    |    |-- hash: integer (nullable = false)
     |-- name: string (nullable = true)
     |-- hash: integer (nullable = false)
    
    +--------------------+----+-----------+
    |          experience|name|       hash|
    +--------------------+----+-----------+
    |[{[{cn111, loc111...|  n1|-1931528302|
    |[{[{cn211, loc211...|  n2|  312789015|
    +--------------------+----+-----------+
    

    Remarks:

    • the code assumes that arrays only contain struct elements (as in the question)
    • I use Spark's built-in hash function (which returns an integer) instead of md5
    • for me, using generators simplifies handling the recursion, but that's probably a matter of personal taste
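
    For anyone wanting to run the answer end to end, here is a minimal sketch of sample input data. The literal values are my reconstruction from the truncated output above (not the answerer's actual test data), and the displayed column order may differ from the answer's run:

    data = [
        ("n1", [("r11", "d11", [("cn111", "loc111")])]),
        ("n2", [("r21", "d21", [("cn211", "loc211")])]),
    ]
    df = spark.createDataFrame(data, schema)  # 'schema' as defined in the question

    df.select(list(add_hashes(df))).show(truncate=True)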