I am having trouble building a custom aggregation where the join keys differ on each row. Can someone help me, please?
I have been stuck on a problem for some time. I have a huge dataframe of transactions with a format close to this:
flat_data = {
    'year': [2022, 2022, 2022, 2023, 2023, 2023, 2023, 2023, 2023],
    'month': [1, 1, 2, 1, 2, 2, 3, 3, 3],
    'operator': ['A', 'A', 'B', 'A', 'B', 'B', 'C', 'C', 'C'],
    'value': [10, 15, 20, 8, 12, 15, 30, 40, 50],
    'attribute1': ['x', 'x', 'y', 'x', 'y', 'z', 'x', 'z', 'x'],
    'attribute2': ['apple', 'apple', 'banana', 'apple', 'banana', 'banana', 'apple', 'banana', 'banana'],
    'attribute3': ['dog', 'cat', 'dog', 'cat', 'rabbit', 'turtle', 'cat', 'dog', 'dog'],
}
I have over 80 attributes.
On the other hand, I have a totals dataframe that looks like this:
totals = {
    'year': [2022, 2022, 2023, 2023, 2023],
    'month': [1, 2, 1, 2, 3],
    'operator': ['A', 'B', 'A', 'B', 'C'],
    'id': ['id1', 'id2', 'id1', 'id2', 'id3'],
    'attribute1': [None, 'y', 'x', 'z', 'x'],
    'attribute2': ['apple', None, 'apple', 'banana', 'banana'],
}
The totals dataframe only has attributes that I can find in flat_data, plus an extra id column. What I am trying to do is get a result dataframe with year, month, operator, id and value. For that I need to sum the values of all flat_data rows that match the attributes of the totals row, but only the non-null ones.
My desired output looks like:
result = {
    'year': [2022, 2022, 2023, 2023, 2023],
    'month': [1, 2, 1, 2, 3],
    'operator': ['A', 'B', 'A', 'B', 'C'],
    'id': ['id1', 'id2', 'id1', 'id2', 'id3'],
    'sum': [25, 20, 8, 15, 50],
}
where sum is the sum of the values of all rows whose attributes match the non-null attributes of that id.
For example, id1 would match every row of 01/2022 with the same operator (operator A) and attribute2 = 'apple', regardless of attribute1 (rows 1 and 2), so my total for id1 for operator A for 01/2022 would be 25.
I tried looping through the rows, but it is error-prone and memory-hungry. I want to try PySpark but cannot figure out how to distribute the task. I have managed to do it on a row-by-row basis, meaning a join on attributes and then a groupby + sum. Where I am stuck is that, in effect, each totals row has its own set of join keys because of the null constraint (i.e. a null in the filter matches everything), so I cannot generalise the approach.
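For concreteness, here is roughly the row-by-row logic I mean, as a minimal pandas sketch over the sample data above (my real code does the same thing over the 80+ attribute columns):
import pandas as pd

flat_pd = pd.DataFrame(flat_data)
totals_pd = pd.DataFrame(totals)

# For each totals row, filter flat_data on the non-null attributes
# and sum the matching values. Correct, but loops over every totals row.
sums = []
for _, t in totals_pd.iterrows():
    mask = (
        (flat_pd['year'] == t['year'])
        & (flat_pd['month'] == t['month'])
        & (flat_pd['operator'] == t['operator'])
    )
    for attr in ('attribute1', 'attribute2'):
        if pd.notna(t[attr]):  # a null in totals matches everything
            mask &= flat_pd[attr] == t[attr]
    sums.append(flat_pd.loc[mask, 'value'].sum())

totals_pd['sum'] = sums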
I hope I have understood the question correctly. Check out this solution:
import pyspark.sql.functions as f
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("pyspark_playground").getOrCreate()
flat_data = {
    'year': [2022, 2022, 2022, 2023, 2023, 2023, 2023, 2023, 2023],
    'month': [1, 1, 2, 1, 2, 2, 3, 3, 3],
    'operator': ['A', 'A', 'B', 'A', 'B', 'B', 'C', 'C', 'C'],
    'value': [10, 15, 20, 8, 12, 15, 30, 40, 50],
    'attribute1': ['x', 'x', 'y', 'x', 'y', 'z', 'x', 'z', 'x'],
    'attribute2': ['apple', 'apple', 'banana', 'apple', 'banana', 'banana', 'apple', 'banana', 'banana'],
    'attribute3': ['dog', 'cat', 'dog', 'cat', 'rabbit', 'turtle', 'cat', 'dog', 'dog'],
}
totals = {
    'year': [2022, 2022, 2023, 2023, 2023],
    'month': [1, 2, 1, 2, 3],
    'operator': ['A', 'B', 'A', 'B', 'C'],
    'id': ['id1', 'id2', 'id1', 'id2', 'id3'],
    'attribute1': [None, 'y', 'x', 'z', 'x'],
    'attribute2': ['apple', None, 'apple', 'banana', 'banana'],
}
# Build Spark DataFrames from the column-oriented dicts (zip turns the columns into rows)
flat_data_df = spark.createDataFrame(list(zip(*flat_data.values())), list(flat_data.keys()))
totals_df = spark.createDataFrame(list(zip(*totals.values())), list(totals.keys()))
output_df = (
    flat_data_df.alias('flat')
    .join(
        totals_df.alias('total'),
        # the base keys must always match exactly
        (flat_data_df.year == totals_df.year) &
        (flat_data_df.month == totals_df.month) &
        (flat_data_df.operator == totals_df.operator) &
        # a null attribute in totals acts as a wildcard: the row matches
        # when the attributes are equal or the totals attribute is null
        ((flat_data_df.attribute1 == totals_df.attribute1) | (totals_df.attribute1.isNull())) &
        ((flat_data_df.attribute2 == totals_df.attribute2) | (totals_df.attribute2.isNull())),
        "inner"
    )
    .select('flat.year', 'flat.month', 'flat.operator', 'total.id', 'flat.value')
    .groupBy('year', 'month', 'operator', 'id')
    .agg(f.sum('value').alias('sum'))
)
output_df.show()
and the output is:
+----+-----+--------+---+---+
|year|month|operator| id|sum|
+----+-----+--------+---+---+
|2022| 1| A|id1| 25|
|2022| 2| B|id2| 20|
|2023| 1| A|id1| 8|
|2023| 2| B|id2| 15|
|2023| 3| C|id3| 50|
+----+-----+--------+---+---+
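Since you mention 80+ attribute columns, you probably don't want to write the join condition out by hand. Here is a sketch that builds the same condition programmatically; it assumes the attribute columns are everything in totals_df other than the base keys and id:
from functools import reduce

# attribute columns = everything in totals except the base keys and id
base_keys = ['year', 'month', 'operator']
attr_cols = [c for c in totals_df.columns if c not in base_keys + ['id']]

# the base keys must always match exactly
base_cond = [f.col(f'flat.{k}') == f.col(f'total.{k}') for k in base_keys]

# a null attribute in totals acts as a wildcard for that column
attr_cond = [
    f.col(f'total.{c}').isNull() | (f.col(f'flat.{c}') == f.col(f'total.{c}'))
    for c in attr_cols
]

output_df = (
    flat_data_df.alias('flat')
    .join(totals_df.alias('total'), reduce(lambda a, b: a & b, base_cond + attr_cond), 'inner')
    .groupBy('flat.year', 'flat.month', 'flat.operator', 'total.id')
    .agg(f.sum('flat.value').alias('sum'))
)
Note that the inner join drops any totals row that matches no flat_data rows; if those ids should still appear with a sum of 0, you could left-join from totals instead and wrap the aggregation in f.coalesce(f.sum('flat.value'), f.lit(0)).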