apache-spark, pyspark, apache-spark-sql, aggregate, grouping

How can I count different groups and group them into one column in PySpark?


In this example, I have the following dataframe:

client_id   rule_1   rule_2   rule_3   rule_4   rule_5
    1         1        0         1       0        0
    2         0        1         0       0        0
    3         0        1         1       1        0
    4         1        0         1       1        1

It shows the client_id and whether or not each client obeys a certain rule.

How can I count the number of clients that obey (and that don't obey) each rule, and show all of that information in a single dataframe?

rule    obeys    count
rule_1    0      23852
rule_1    1      95102
rule_2    0      12942
rule_2    1      45884
rule_3    0      29319
rule_3    1       9238
rule_4    0      55321
rule_4    1      23013
rule_5    0      96842
rule_5    1      86739

Solution

  • The operation of moving column names into rows is called unpivoting. In Spark, it can be done with the stack function (newer Spark versions also ship a built-in unpivot; see the sketch in the next bullet).

    Input:

    from pyspark.sql import functions as F
    df = spark.createDataFrame(
        [(1, 1, 0, 1, 0, 0),
         (2, 0, 1, 0, 0, 0),
         (3, 0, 1, 1, 1, 0),
         (4, 1, 0, 1, 1, 1)],
        ["client_id", "rule_1", "rule_2", "rule_3", "rule_4", "rule_5"])
    

    Script:

    # Build the pairs "'rule_1', `rule_1`, 'rule_2', `rule_2`, ..." that stack() expects:
    # a string label followed by the column it labels
    to_unpivot = [f"'{c}', `{c}`" for c in df.columns if c != "client_id"]
    stack_str = ",".join(to_unpivot)
    # stack(n, lbl1, col1, ..., lbln, coln) emits one (rule, obeys) row per pair
    df = (df
        .select(F.expr(f"stack({len(to_unpivot)}, {stack_str}) as (rule, obeys)"))
        .groupBy("rule", "obeys")
        .count()
    )
    df.show()
    # +------+-----+-----+
    # |  rule|obeys|count|
    # +------+-----+-----+
    # |rule_1|    1|    2|
    # |rule_2|    1|    2|
    # |rule_1|    0|    2|
    # |rule_3|    1|    3|
    # |rule_2|    0|    2|
    # |rule_4|    0|    2|
    # |rule_3|    0|    1|
    # |rule_5|    0|    3|
    # |rule_5|    1|    1|
    # |rule_4|    1|    2|
    # +------+-----+-----+
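
  • The output above is unordered; append .orderBy("rule", "obeys") before show() to match the ordering in the question. If you are on Spark 3.4 or later, the built-in DataFrame.unpivot can replace the hand-built stack expression. A minimal sketch, assuming df still holds the original input dataframe from above (not the aggregated result):

    rule_cols = [c for c in df.columns if c != "client_id"]
    result = (df
        .unpivot("client_id", rule_cols, "rule", "obeys")  # column names become (rule, obeys) rows
        .groupBy("rule", "obeys")
        .count()
        .orderBy("rule", "obeys")  # optional: same ordering as the question's expected output
    )
    result.show()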