I am looking for a way to divide my large Spark dataset into groups/batches and process each group of rows in a function. Basically, a group of rows should be the input to my function, and the output is Unit for me, since I don't want to aggregate or update the input records but just perform some calculation.
To illustrate, let's say I have the following input:
| Col1 | Col2 | Col3 |
|------|------|------|
| 1 | A | 1 |
| 1 | B | 2 |
| 1 | C | 3 |
| 1 | A | 4 |
| 1 | A | 5 |
| 1 | C | 6 |
| 2 | X | 7 |
| 2 | X | 8 |
Now let's say I need to group by Col1 and Col2, which gives the following groups:

- 1st group: (1, A, 1), (1, A, 4), (1, A, 5)
- 2nd group: (1, B, 2)
- 3rd group: (1, C, 3), (1, C, 6)
- 4th group: (2, X, 7), (2, X, 8)
I want to pass these groups to my function to perform some logic. For now, let's just say I am summing Col3 in that method (this is not my actual requirement, but let's assume I want to do that summation in a separate method) to generate the following output:
| Col1 | Col2 | Col3 |
|------|------|------|
| 1 | A | 10 |
| 1 | B | 2 |
| 1 | C | 9 |
| 2 | X | 15 |
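To pin down the intended semantics, here is a plain-Python sketch (no Spark involved) that groups the sample rows by (Col1, Col2) and hands each group to a separate function. `apply_per_group` and `sum_col3` are hypothetical names for illustration; this runs on a single machine and is not a way to process 500 million rows.

```python
from itertools import groupby
from operator import itemgetter

rows = [
    (1, "A", 1), (1, "B", 2), (1, "C", 3), (1, "A", 4),
    (1, "A", 5), (1, "C", 6), (2, "X", 7), (2, "X", 8),
]

def sum_col3(group_rows):
    """Hypothetical per-group logic: sum Col3 over the rows of one group."""
    return sum(row[2] for row in group_rows)

def apply_per_group(rows, key_fn, fn):
    """Group rows by key_fn and call fn once per group; returns {key: fn(group)}."""
    results = {}
    # itertools.groupby only merges *adjacent* rows with equal keys, so sort first
    for key, group in groupby(sorted(rows, key=key_fn), key=key_fn):
        results[key] = fn(list(group))
    return results

result = apply_per_group(rows, itemgetter(0, 1), sum_col3)
for (col1, col2), total in sorted(result.items()):
    print(col1, col2, total)
```

If the real per-group function returns nothing (Unit), `fn` can simply be a side-effecting callable and the returned dict will hold `None` values.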
How can I achieve this? Based on some suggestions, I tried looking at UDAFs but couldn't figure out how to use one. Please note that my real input dataset has more than 500 million records. Thanks.
Here is a simple example based on your input to get you started:
```python
from pyspark.sql.types import IntegerType
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()

data = [
    (1, "A", 1),
    (1, "B", 2),
    (1, "C", 3),
    (1, "A", 4),
    (1, "A", 5),
    (1, "C", 6),
    (2, "X", 7),
    (2, "X", 8),
]
df = spark.createDataFrame(data, ["col1", "col2", "col3"])
df.show()
```
```
+----+----+----+
|col1|col2|col3|
+----+----+----+
|   1|   A|   1|
|   1|   B|   2|
|   1|   C|   3|
|   1|   A|   4|
|   1|   A|   5|
|   1|   C|   6|
|   2|   X|   7|
|   2|   X|   8|
+----+----+----+
```
```python
# define your function - pure Python here, no Spark needed
def dummy_f(xs):
    return sum(xs)

# apply your function as UDF - needs input function and return type (integer here)
(df.groupBy(F.col("col1"), F.col("col2"))
    .agg(F.collect_list("col3").alias("col3"))  # one list of col3 values per group
    .withColumn("col3sum", F.udf(dummy_f, IntegerType())(F.col("col3")))
    .orderBy("col1", "col2")
    .show())
```
```
+----+----+---------+-------+
|col1|col2|     col3|col3sum|
+----+----+---------+-------+
|   1|   A|[1, 4, 5]|     10|
|   1|   B|      [2]|      2|
|   1|   C|   [3, 6]|      9|
|   2|   X|   [7, 8]|     15|
+----+----+---------+-------+
```
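Because `dummy_f` is plain Python, it can be checked without a Spark session. The sketch below feeds it the per-group lists that `collect_list` produced in the output above; the `groups` dict is just those lists copied by hand, not something Spark returns in this form.

```python
def dummy_f(xs):
    return sum(xs)

# col3 lists per (col1, col2) group, copied from the collect_list output above
groups = {
    (1, "A"): [1, 4, 5],
    (1, "B"): [2],
    (1, "C"): [3, 6],
    (2, "X"): [7, 8],
}

# Same values the col3sum column shows, computed without Spark
sums = {key: dummy_f(xs) for key, xs in groups.items()}
print(sums)
```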
The key is to aggregate the columns into whatever shape your function expects. You can use `create_map` to build a dict, or `collect_list` to build a list as shown here.
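As a hedged illustration of the `create_map` route: if you grouped by `col1` only and aggregated with something like `F.collect_list(F.create_map("col2", "col3"))`, each group would reach a Python UDF as a list of single-entry dicts (assuming PySpark's usual conversion of map columns to Python `dict`). The hypothetical `merge_maps` below is a per-group function written against that shape; registering it as a UDF would also need a map return type such as `MapType(StringType(), IntegerType())`.

```python
from collections import defaultdict

def merge_maps(maps):
    """Hypothetical per-group function: combine a list of {col2: col3} dicts
    into one dict of per-col2 totals (order of the input list does not matter)."""
    totals = defaultdict(int)
    for m in maps:
        for key, value in m.items():
            totals[key] += value
    return dict(totals)

# What the col1 == 1 group might look like as a list of single-entry dicts
group_col1_1 = [{"A": 1}, {"B": 2}, {"C": 3}, {"A": 4}, {"A": 5}, {"C": 6}]
print(merge_maps(group_col1_1))
```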