Tags: apache-spark, pyspark, group-by, aggregate, grouping

PySpark: group by Id, summing only consecutive occurrences


I have this dataframe:

Id  Minute      identifier
aa  1           1
aa  1           2
aa  1           3
aa  2 (ignore)  4
aa  1           5
aa  1           6
bb  1           7
bb  1           8
bb  5 (ignore)  9
bb  1           10

My desired output is grouped by "Id", summing each consecutive run of "Minute" values (only where Minute = 1):

Id  Minute
aa  3
aa  2
bb  2
bb  1

Solution

  • It can be done with window functions, used twice: first to create a helper flag column that marks every row where "Minute" changes from the previous row, then to turn that flag into running group numbers via a cumulative sum. Finally, rows are grouped by those group numbers.

    Input:

    from pyspark.sql import functions as F, Window as W
    df = spark.createDataFrame(
        [('aa', 1, 1),
         ('aa', 1, 2),
         ('aa', 1, 3),
         ('aa', 2, 4),
         ('aa', 1, 5),
         ('aa', 1, 6),
         ('bb', 1, 7),
         ('bb', 1, 8),
         ('bb', 5, 9),
         ('bb', 1, 10)],
        ['Id', 'Minute', 'identifier'])
    

    Script:

    # Window per Id, ordered by the original row order (identifier)
    w = W.partitionBy('Id').orderBy('identifier')
    # _flg is 1 whenever Minute differs from the previous row's value;
    # the first row of each partition gets 0 via coalesce
    df = df.withColumn('_flg', F.coalesce(F.when(F.lag("Minute").over(w) != F.col("Minute"), 1), F.lit(0)))
    # Running sum of the flags yields a group number for each consecutive run
    df = df.withColumn('_grp', F.sum('_flg').over(w))
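    # For illustration, the helper columns at this point work out as follows
    # (rows shown ordered by Id, identifier for readability):
    # +---+------+----------+----+----+
    # | Id|Minute|identifier|_flg|_grp|
    # +---+------+----------+----+----+
    # | aa|     1|         1|   0|   0|
    # | aa|     1|         2|   0|   0|
    # | aa|     1|         3|   0|   0|
    # | aa|     2|         4|   1|   1|
    # | aa|     1|         5|   1|   2|
    # | aa|     1|         6|   0|   2|
    # | bb|     1|         7|   0|   0|
    # | bb|     1|         8|   0|   0|
    # | bb|     5|         9|   1|   1|
    # | bb|     1|        10|   1|   2|
    # +---+------+----------+----+----+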
    # Keep only Minute == 1 rows, then count the length of each consecutive run
    df = (df
        .filter(F.col('Minute') == 1)
        .groupBy('Id', '_grp')
        .agg(F.count(F.lit(1)).alias('Minute'))
        .drop('_grp')
    )
    df.show()
    # +---+------+
    # | Id|Minute|
    # +---+------+
    # | aa|     3|
    # | aa|     2|
    # | bb|     2|
    # | bb|     1|
    # +---+------+
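
    Note that show() after a groupBy makes no ordering guarantee, so the tidy row order above is incidental. If the runs should come back in their original order, one option (a sketch, not part of the original answer; 'start' is a hypothetical helper column) is to keep each run's first identifier and sort by it before dropping the helper columns:

    # Variant of the final aggregation step, applied to the frame that still
    # holds the _grp and identifier columns (i.e. before the last reassignment above)
    result = (df
        .filter(F.col('Minute') == 1)
        .groupBy('Id', '_grp')
        .agg(F.count(F.lit(1)).alias('Minute'),
             F.min('identifier').alias('start'))  # first row of each run
        .orderBy('Id', 'start')
        .drop('_grp', 'start')
    )
    result.show()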