
Aggregate values from a dataframe based on criteria if not null


I am having trouble building a custom aggregation where the join keys are different on each row. Can someone help me, please?

I have been stuck on a problem for some time. I have a huge dataframe of transactions in a format close to this:

flat_data = {
    'year': [2022, 2022, 2022, 2023, 2023, 2023, 2023, 2023, 2023],
    'month': [1, 1, 2, 1, 2, 2, 3, 3, 3],
    'operator': ['A', 'A', 'B', 'A', 'B', 'B', 'C', 'C', 'C'],
    'value': [10, 15, 20, 8, 12, 15, 30, 40, 50],
    'attribute1': ['x', 'x', 'y', 'x', 'y', 'z', 'x', 'z', 'x'],
    'attribute2': ['apple', 'apple', 'banana', 'apple', 'banana', 'banana', 'apple', 'banana', 'banana'],
    'attribute3': ['dog', 'cat', 'dog', 'cat', 'rabbit', 'turtle', 'cat', 'dog', 'dog'],
}

I have over 80 attributes.

On the other hand, I have a totals dataframe that looks like this:

totals = {
    'year': [2022, 2022, 2023, 2023, 2023],
    'month': [1, 2, 1, 2, 3],
    'operator': ['A', 'B', 'A', 'B', 'C'],
    'id': ['id1', 'id2', 'id1', 'id2', 'id3'], 
    'attribute1': [None, 'y', 'x', 'z', 'x'],
    'attribute2': ['apple', None, 'apple', 'banana', 'banana'],
}

The totals dataframe only has attributes that also exist in flat_data, plus an extra id column. What I am trying to do is build a result dataframe with year, month, operator, id and a summed value. For that I need to sum the values of all rows of flat_data that match the attributes of a totals row, but only the non-null ones.

My desired output looks like:

result= {
    'year': [2022, 2022, 2023, 2023, 2023],
    'month': [1, 2, 1, 2, 3],
    'operator': ['A', 'B', 'A', 'B', 'C'],
    'id': ['id1', 'id2', 'id1', 'id2', 'id3'],
    'sum': [25, 20, 8, 15, 50],
}

where sum is the sum of the values of all rows whose attributes match the non-null attributes of the corresponding id.

For example, id1 would match every row of 01/2022 with the same operator (operator A) and attribute2 = apple, regardless of attribute1 (rows 1 and 2), so my total for id1 / operator A / 01/2022 would be 25.
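
To make the rule concrete, here is a minimal pandas sketch of the matching logic for a single totals row (using the flat_data and totals dicts above; the attribute list is shortened to the two example attributes, but in practice it would cover all 80+):

import pandas as pd

flat_pdf = pd.DataFrame(flat_data)
totals_pdf = pd.DataFrame(totals)

# First totals row: id1, 2022-01, operator A, attribute1=None, attribute2='apple'
row = totals_pdf.iloc[0]

# Always match on the year / month / operator keys
mask = (
    (flat_pdf['year'] == row['year'])
    & (flat_pdf['month'] == row['month'])
    & (flat_pdf['operator'] == row['operator'])
)

# Only constrain on attributes that are non-null in the totals row
for col in ['attribute1', 'attribute2']:
    if pd.notna(row[col]):
        mask &= flat_pdf[col] == row[col]

print(flat_pdf.loc[mask, 'value'].sum())  # 25, matching rows 1 and 2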

I tried looping through the rows, but that is error-prone and memory-hungry. I want to try PySpark but cannot find how to distribute the task. I have managed to do it on a row-by-row basis, meaning a join on the attributes followed by a groupBy + sum. Where I am stuck is that, in effect, each row has its own set of join keys because of the null constraint (i.e. a null in totals matches everything), so I cannot generalise the approach.


Solution

  • I hope I understand the question correctly. Check out this solution:

    import pyspark.sql.functions as f
    from pyspark.sql import SparkSession
    
    spark = SparkSession.builder.appName("pyspark_playground").getOrCreate()
    
    flat_data = {
        'year': [2022, 2022, 2022, 2023, 2023, 2023, 2023, 2023, 2023],
        'month': [1, 1, 2, 1, 2, 2, 3, 3, 3],
        'operator': ['A', 'A', 'B', 'A', 'B', 'B', 'C', 'C', 'C'],
        'value': [10, 15, 20, 8, 12, 15, 30, 40, 50],
        'attribute1': ['x', 'x', 'y', 'x', 'y', 'z', 'x', 'z', 'x'],
        'attribute2': ['apple', 'apple', 'banana', 'apple', 'banana', 'banana', 'apple', 'banana', 'banana'],
        'attribute3': ['dog', 'cat', 'dog', 'cat', 'rabbit', 'turtle', 'cat', 'dog', 'dog'],
    }
    totals = {
        'year': [2022, 2022, 2023, 2023, 2023],
        'month': [1, 2, 1, 2, 3],
        'operator': ['A', 'B', 'A', 'B', 'C'],
        'id': ['id1', 'id2', 'id1', 'id2', 'id3'], 
        'attribute1': [None, 'y', 'x', 'z', 'x'],
        'attribute2': ['apple', None, 'apple', 'banana', 'banana'],
    }
    flat_data_df = spark.createDataFrame(list(zip(*flat_data.values())), list(flat_data.keys()))
    totals_df = spark.createDataFrame(list(zip(*totals.values())), list(totals.keys()))
    
    output_df = (
        flat_data_df.alias('flat')
        .join(
            totals_df.alias('total'),
            # exact match on the year / month / operator keys
            (flat_data_df.year == totals_df.year) &
            (flat_data_df.month == totals_df.month) &
            (flat_data_df.operator == totals_df.operator) &
            # a null attribute in totals matches every value in flat_data
            ((flat_data_df.attribute1 == totals_df.attribute1) | (totals_df.attribute1.isNull())) &
            ((flat_data_df.attribute2 == totals_df.attribute2) | (totals_df.attribute2.isNull())),
            "inner"
        )
        # keep one row per matching (flat row, totals row) pair, then aggregate
        .select('flat.year', 'flat.month', 'flat.operator', 'total.id', 'flat.value')
        .groupBy('year', 'month', 'operator', 'id')
        .agg(f.sum('value').alias('sum'))
    )
    
    output_df.show()
    

    and the output is:

    +----+-----+--------+---+---+
    |year|month|operator| id|sum|
    +----+-----+--------+---+---+
    |2022|    1|       A|id1| 25|
    |2022|    2|       B|id2| 20|
    |2023|    1|       A|id1|  8|
    |2023|    2|       B|id2| 15|
    |2023|    3|       C|id3| 50|
    +----+-----+--------+---+---+
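
    Since the question mentions over 80 attributes, hard-coding each attribute in the join condition does not scale. As a sketch of one way to generalise it (assuming the shared attribute columns can be identified by the attributeN naming pattern; adapt the predicate to your real schema), the same condition can be built programmatically with functools.reduce:

    from functools import reduce

    # All attribute columns present in totals (assumed naming pattern)
    attribute_cols = [c for c in totals_df.columns if c.startswith('attribute')]

    # Exact match on the year / month / operator keys...
    condition = (
        (f.col('flat.year') == f.col('total.year'))
        & (f.col('flat.month') == f.col('total.month'))
        & (f.col('flat.operator') == f.col('total.operator'))
    )

    # ...plus "equal, or null in totals" for every attribute column
    condition = reduce(
        lambda acc, c: acc
        & ((f.col(f'flat.{c}') == f.col(f'total.{c}')) | f.col(f'total.{c}').isNull()),
        attribute_cols,
        condition,
    )

    generalised_df = (
        flat_data_df.alias('flat')
        .join(totals_df.alias('total'), condition, 'inner')
        .groupBy('flat.year', 'flat.month', 'flat.operator', 'total.id')
        .agg(f.sum('flat.value').alias('sum'))
    )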