Tags: apache-spark, pyspark

Apply function to every field of a DataFrame with nested structs and arrays


I want to write a function that receives a Spark DataFrame (PySpark) and, for every field of type Integer, applies a custom function; for this example, let's say it adds 1 to the original value. The function should handle arbitrarily nested structures, including arrays.

Here's an example:

data = {
    'top_level_field':{
        'array_field': [
            {
                'inner_array_field': [1, 2, 3],
                'inner_array_of_structs': [
                    {
                        'field_1': 2,
                        'field_2': 3
                    }
                ]
            },
            {
                'inner_array_field': [1, 2, 3],
                'inner_array_of_structs': [
                    {
                        'field_1': 2,
                        'field_2': 3,
                        'field_3': [1, 2, 3]
                    }
                ]
            }
        ]
    }
}


The output should be:

data = {
    'top_level_field':{
        'array_field': [
            {
                'inner_array_field': [2, 3, 4],
                'inner_array_of_structs': [
                    {
                        'field_1': 3,
                        'field_2': 4
                    }
                ]
            },
            {
                'inner_array_field': [2, 3, 4],
                'inner_array_of_structs': [
                    {
                        'field_1': 3,
                        'field_2': 4,
                        'field_3': [2, 3, 4]
                    }
                ]
            }
        ]
    }
}

If possible, I'd like to avoid UDFs, but if not, that will work as well.

I'm trying to write a recursive function that traverses the schema, finds every IntegerType field, and applies the "plus_one" function, but I'm having trouble getting it right.


Solution

  • UDF + recursive approach

    from pyspark.sql import functions as F
    from pyspark.sql import types as T

    def plus_one(v):
        return v + 1

    def apply(value):
        """Recursively walk a collected value, applying plus_one to numeric leaves."""
        if isinstance(value, (int, float)):
            return plus_one(value)
        elif isinstance(value, list):
            return [apply(v) for v in value]
        elif isinstance(value, T.Row):
            # Structs arrive in the UDF as Rows; rebuild them as dicts, which
            # Spark converts back using the UDF's returnType.
            return {k: apply(v) for k, v in value.asDict().items()}
        else:
            return value

    # Reuse the column's own type string as the UDF's return type so the
    # nested structure round-trips intact.
    dtype = dict(df.dtypes)['top_level_field']
    result = df.withColumn('top_level_field',
                           F.udf(apply, returnType=dtype)('top_level_field'))


    df.show(truncate=False)
    +-----------------------------------------------------------------+
    |top_level_field                                                  |
    +-----------------------------------------------------------------+
    |{[{[1, 2, 3], [{2, 3, NULL}]}, {[1, 2, 3], [{2, 3, [1, 2, 3]}]}]}|
    +-----------------------------------------------------------------+
    
    result.show(truncate=False)
    +-----------------------------------------------------------------+
    |top_level_field                                                  |
    +-----------------------------------------------------------------+
    |{[{[2, 3, 4], [{3, 4, NULL}]}, {[2, 3, 4], [{3, 4, [2, 3, 4]}]}]}|
    +-----------------------------------------------------------------+