Tags: apache-spark, pyspark, group-by, grouping

How to partition by groups of N in PySpark


I have the following data frame:

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from pyspark.sql.window import Window
import pyspark.sql.functions as F

spark = SparkSession.builder.getOrCreate()

data = [
    ( 1, "AAA", "BBB", "CCC", "DDD", "desktop"),
    ( 2, "AAA", "BBB", "CCC", "DDD", "desktop"),
    ( 3, "AAA", "BBB", "CCC", "DDD", "mobile"),
    ( 4, "AAA", "BBB", "CCC", "DDD", "desktop"),
    ( 5, "AAA", "BBB", "CCC", "DDD", "mobile"),
    ( 6, "AAA", "BBB", "CCC", "DDD", "desktop"),
    ( 7, "AAA", "BBB", "CCC", "DDD", "desktop"),
    ( 8, "AAA", "BBB", "CCC", "DDD", "desktop"),
    ( 9, "AAA", "BBB", "CCC", "DDD", "desktop"),
    (10, "AAA", "BBB", "CCC", "DDD", "mobile"),
    (11, "AAA", "BBB", "CCC", "DDD", "desktop"),
    (12, "EEE", "FFF", "GGG", "HHH", "desktop"),
    (13, "EEE", "FFF", "GGG", "HHH", "mobile"),
    (14, "EEE", "FFF", "GGG", "HHH", "desktop"),
    (15, "EEE", "FFF", "GGG", "HHH", "mobile"),
    (16, "EEE", "FFF", "GGG", "HHH", "desktop"),
    (17, "EEE", "FFF", "GGG", "HHH", "desktop"),
    (18, "EEE", "FFF", "GGG", "HHH", "desktop"),
    (19, "III", "JJJ", "KKK", "LLL", "desktop"),
    (20, "III", "JJJ", "KKK", "LLL", "mobile"),
    (21, "III", "JJJ", "KKK", "LLL", "desktop"),
    (22, "III", "JJJ", "KKK", "LLL", "desktop"),
    (23, "III", "JJJ", "KKK", "LLL", "mobile"),
    (24, "III", "JJJ", "KKK", "LLL", "desktop"),
    (25, "III", "JJJ", "KKK", "LLL", "desktop"),
    (26, "III", "JJJ", "KKK", "LLL", "desktop"),
    (27, "III", "JJJ", "KKK", "LLL", "desktop"),
    (28, "III", "JJJ", "KKK", "LLL", "desktop"),
    (29, "III", "JJJ", "KKK", "LLL", "desktop"),
    (30, "III", "JJJ", "KKK", "LLL", "mobile")
]

schema = StructType([
    StructField("id", IntegerType(), True),
    StructField("text", StringType(), True),
    StructField("title", StringType(), True),
    StructField("target_url", StringType(), True),
    StructField("display_domain", StringType(), True),
    StructField("device", StringType(), True),
])

df = spark.createDataFrame(data=data, schema=schema)

columns = [
    "text",
    "title",
    "target_url",
    "display_domain"
]

windowSpecByPartition = Window.partitionBy(*columns).orderBy("id")

overall_row_number_df = df.withColumn(
    "overall_row_number",
    F.row_number().over(windowSpecByPartition)
)

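Within each partition (rows sharing the same text, title, target_url and display_domain), I want to split the rows into groups of 5, keeping the last group even if it has fewer than 5 rows.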

The table I expect is:

id  text  title  target_url  display_domain  device   group_id
1   AAA   BBB    CCC         DDD             desktop  1
2   AAA   BBB    CCC         DDD             desktop  1
3   AAA   BBB    CCC         DDD             mobile   1
4   AAA   BBB    CCC         DDD             desktop  1
5   AAA   BBB    CCC         DDD             mobile   1
6   AAA   BBB    CCC         DDD             desktop  2
7   AAA   BBB    CCC         DDD             desktop  2
8   AAA   BBB    CCC         DDD             desktop  2
9   AAA   BBB    CCC         DDD             desktop  2
10  AAA   BBB    CCC         DDD             mobile   2
11  AAA   BBB    CCC         DDD             desktop  3
12  EEE   FFF    GGG         HHH             desktop  4
13  EEE   FFF    GGG         HHH             mobile   4
14  EEE   FFF    GGG         HHH             desktop  4
15  EEE   FFF    GGG         HHH             mobile   4
16  EEE   FFF    GGG         HHH             desktop  4
17  EEE   FFF    GGG         HHH             desktop  5
18  EEE   FFF    GGG         HHH             desktop  5
19  III   JJJ    KKK         LLL             desktop  6
20  III   JJJ    KKK         LLL             mobile   6
21  III   JJJ    KKK         LLL             desktop  6
22  III   JJJ    KKK         LLL             desktop  6
23  III   JJJ    KKK         LLL             mobile   6
24  III   JJJ    KKK         LLL             desktop  7
25  III   JJJ    KKK         LLL             desktop  7
26  III   JJJ    KKK         LLL             desktop  7
27  III   JJJ    KKK         LLL             desktop  7
28  III   JJJ    KKK         LLL             desktop  7
29  III   JJJ    KKK         LLL             desktop  8
30  III   JJJ    KKK         LLL             mobile   8

In the end I need 8 groups of data. This is very similar to the behavior of Rails' in_groups_of.
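The expected group_id column can be checked with plain Python arithmetic. This is a minimal sketch, assuming the rows are already sorted by the key columns and then by id; the helper `assign_group_ids` is hypothetical, and a tuple of the four key values stands in for the partition columns. Within each key, ceil(row_number / 5) gives the sub-group, and a running counter turns every new (key, sub-group) pair into the next global id: 11 rows give 3 groups, 7 rows give 2, and 12 rows give 3, for 8 in total.

```python
from itertools import groupby
from math import ceil


def assign_group_ids(rows, n=5):
    """Hypothetical helper: one global group id per row, at most n rows per group.

    rows: (id, text, title, target_url, display_domain) tuples,
    sorted by the four key values and then by id.
    """
    group_ids = []
    next_id = 0
    # Consecutive rows with the same key values form one partition.
    for _key, items in groupby(rows, key=lambda r: r[1:]):
        previous_sub_group = None
        for row_number, _row in enumerate(items, start=1):
            sub_group = ceil(row_number / n)  # rows 1-5 -> 1, 6-10 -> 2, ...
            if sub_group != previous_sub_group:
                next_id += 1  # a new (key, sub_group) pair starts a new global group
                previous_sub_group = sub_group
            group_ids.append(next_id)
    return group_ids


rows = (
    [(i, "AAA", "BBB", "CCC", "DDD") for i in range(1, 12)]
    + [(i, "EEE", "FFF", "GGG", "HHH") for i in range(12, 19)]
    + [(i, "III", "JJJ", "KKK", "LLL") for i in range(19, 31)]
)
ids = assign_group_ids(rows)
print(max(ids))  # 8
```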


Solution

  • You can divide the row number by 5 and round up (ceil) to get an id for each group of 5 rows within a partition, then group by the partition columns plus that additional id:

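    (
        overall_row_number_df
        # ceil(row_number / 5): rows 1-5 -> 1, rows 6-10 -> 2, row 11 -> 3, within each partition
        .withColumn("sub_group", F.ceil(F.col("overall_row_number") / 5))
        .groupBy(columns + ["sub_group"])
        .agg(F.collect_list("device").alias("devices"))  # element order is not guaranteed after a shuffle
        .orderBy(columns + ["sub_group"])
    ).show(truncate=False)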
    

    gives you:

    text  title  target_url  display_domain  sub_group  devices
    AAA   BBB    CCC         DDD             1          [desktop, desktop, mobile, desktop, mobile]
    AAA   BBB    CCC         DDD             2          [desktop, desktop, desktop, desktop, mobile]
    AAA   BBB    CCC         DDD             3          [desktop]
    EEE   FFF    GGG         HHH             1          [desktop, mobile, desktop, mobile, desktop]
    EEE   FFF    GGG         HHH             2          [desktop, desktop]
    III   JJJ    KKK         LLL             1          [desktop, mobile, desktop, desktop, mobile]
    III   JJJ    KKK         LLL             2          [desktop, desktop, desktop, desktop, desktop]
    III   JJJ    KKK         LLL             3          [desktop, mobile]