
Aggregating columns conditionally with pyspark?


I have the following dataset. I want to group by all the variables and split the data based on the conditions below.

However, I am getting an error when I try the code below.

CUST_ID NAME    GENDER  AGE
id_01   MONEY   F   43
id_02   BAKER   F   32
id_03   VOICE   M   31
id_04   TIME    M   56
id_05   TIME    F   24
id_06   TALENT  F   28
id_07   ISLAND  F   21
id_08   ISLAND  F   27
id_09   TUME    F   24
id_10   TIME    F   75
id_11   SKY     M   35
id_12   VOICE   M   70



    from pyspark.sql.functions import *

    df.groupBy("CUST_ID", "NAME", "GENDER", "AGE").agg(
       CUST_ID.count AS TOTAL
       SUM(WHEN ((AGE >= 18 AND AGE <= 34) AND GENDER = 'M') THEN COUNT(CUST_ID) ELSE 0 END AS "M18-34")
       SUM(WHEN ((AGE >= 18 AND AGE <= 34) AND GENDER = 'F') THEN COUNT(CUST_ID) ELSE 0 END AS "F18-34")
       SUM(WHEN ((AGE >= 18 AND AGE <= 34 THEN COUNT(CUST_ID) ELSE 0 END AS "18-34")
       SUM(WHEN ((AGE >= 25 AND AGE <= 54 THEN COUNT(CUST_ID) ELSE 0 END AS "25-54")
       SUM(WHEN ((AGE >= 25 AND AGE <= 54) AND GENDER = 'F') THEN COUNT(CUST_ID) ELSE 0 END AS "F25-54")
       SUM(WHEN ((AGE >= 25 AND AGE <= 54) AND GENDER = 'M') THEN COUNT(CUST_ID) ELSE 0 END AS "M25-54")   
    )

I would appreciate your help/suggestions.

Thanks in advance.


Solution

  • Your code is neither valid pyspark nor valid Spark SQL; there are many syntax problems. I attempted to fix them below, though I'm not sure that's exactly what you wanted. With this many SQL-like expressions, it's easier to use Spark SQL directly rather than the pyspark API:

    df.createOrReplaceTempView('df')
    result = spark.sql("""
    SELECT NAME,
           COUNT(CUST_ID) AS TOTAL,
           SUM(CASE WHEN ((AGE >= 18 AND AGE <= 34) AND GENDER = 'M') THEN 1 ELSE 0 END) AS `M18-34`,
           SUM(CASE WHEN ((AGE >= 18 AND AGE <= 34) AND GENDER = 'F') THEN 1 ELSE 0 END) AS `F18-34`,
           SUM(CASE WHEN (AGE >= 18 AND AGE <= 34) THEN 1 ELSE 0 END) AS `18-34`,
           SUM(CASE WHEN (AGE >= 25 AND AGE <= 54) THEN 1 ELSE 0 END) AS `25-54`,
           SUM(CASE WHEN ((AGE >= 25 AND AGE <= 54) AND GENDER = 'F') THEN 1 ELSE 0 END) AS `F25-54`,
           SUM(CASE WHEN ((AGE >= 25 AND AGE <= 54) AND GENDER = 'M') THEN 1 ELSE 0 END) AS `M25-54` 
    FROM df
    GROUP BY NAME
    """)
    
    result.show()
    +------+-----+------+------+-----+-----+------+------+
    |  NAME|TOTAL|M18-34|F18-34|18-34|25-54|F25-54|M25-54|
    +------+-----+------+------+-----+-----+------+------+
    |ISLAND|    2|     0|     2|    2|    1|     1|     0|
    | MONEY|    1|     0|     0|    0|    1|     1|     0|
    |  TIME|    3|     0|     1|    1|    0|     0|     0|
    | VOICE|    2|     1|     0|    1|    1|     0|     1|
    |  TUME|    1|     0|     1|    1|    0|     0|     0|
    | BAKER|    1|     0|     1|    1|    1|     1|     0|
    |TALENT|    1|     0|     1|    1|    1|     1|     0|
    |   SKY|    1|     0|     0|    0|    1|     0|     1|
    +------+-----+------+------+-----+-----+------+------+
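    To reproduce the result above, the sample data from the question can be loaded into a DataFrame like this (a minimal sketch; the local SparkSession setup is only needed outside a pyspark shell or notebook, where `spark` already exists):

    ```python
    from pyspark.sql import SparkSession

    # Only needed outside a pyspark shell/notebook, where `spark` already exists
    spark = SparkSession.builder.master("local[1]").appName("agg-demo").getOrCreate()

    # Sample data copied from the question
    data = [
        ("id_01", "MONEY", "F", 43), ("id_02", "BAKER", "F", 32),
        ("id_03", "VOICE", "M", 31), ("id_04", "TIME", "M", 56),
        ("id_05", "TIME", "F", 24), ("id_06", "TALENT", "F", 28),
        ("id_07", "ISLAND", "F", 21), ("id_08", "ISLAND", "F", 27),
        ("id_09", "TUME", "F", 24), ("id_10", "TIME", "F", 75),
        ("id_11", "SKY", "M", 35), ("id_12", "VOICE", "M", 70),
    ]
    df = spark.createDataFrame(data, ["CUST_ID", "NAME", "GENDER", "AGE"])
    ```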
    

    If you want a pure pyspark solution, here's an example of how to do it for the first conditional column; the remaining columns follow the same pattern.

    import pyspark.sql.functions as F
    result = df.groupBy('Name').agg(
        F.count('CUST_ID').alias('TOTAL'),
        F.count(F.when(F.expr("(AGE >= 18 AND AGE <= 34) AND GENDER = 'M'"), 1)).alias("M18-34")
    )
    
    result.show()
    +------+-----+------+
    |  Name|TOTAL|M18-34|
    +------+-----+------+
    |ISLAND|    2|     0|
    | MONEY|    1|     0|
    |  TIME|    3|     0|
    | VOICE|    2|     1|
    |  TUME|    1|     0|
    | BAKER|    1|     0|
    |TALENT|    1|     0|
    |   SKY|    1|     0|
    +------+-----+------+