apache-spark, pyspark, apache-spark-sql, combinations, aggregation

Create rows for 0 values when aggregating all combinations of several columns


Using the example in this question, how do I create rows with a count of 0 for combinations that don't occur when aggregating all possible combinations? When using cube, rows with a count of 0 are not produced.

This is the code and output:

df.cube($"x", $"y").count.show

// +----+----+-----+     
// |   x|   y|count|
// +----+----+-----+
// |null|   1|    1|   <- count of records where y = 1
// |null|   2|    3|   <- count of records where y = 2
// | foo|null|    2|   <- count of records where x = foo
// | bar|   2|    2|   <- count of records where x = bar AND y = 2
// | foo|   1|    1|   <- count of records where x = foo AND y = 1
// | foo|   2|    1|   <- count of records where x = foo AND y = 2
// |null|null|    4|   <- total count of records
// | bar|null|    2|   <- count of records where x = bar
// +----+----+-----+

But this is the desired output (note the added row 4, where x = bar AND y = 1 with count 0).

// +----+----+-----+     
// |   x|   y|count|
// +----+----+-----+
// |null|   1|    1|   <- count of records where y = 1
// |null|   2|    3|   <- count of records where y = 2
// | foo|null|    2|   <- count of records where x = foo
// | bar|   1|    0|   <- count of records where x = bar AND y = 1
// | bar|   2|    2|   <- count of records where x = bar AND y = 2
// | foo|   1|    1|   <- count of records where x = foo AND y = 1
// | foo|   2|    1|   <- count of records where x = foo AND y = 2
// |null|null|    4|   <- total count of records
// | bar|null|    2|   <- count of records where x = bar
// +----+----+-----+

Is there another function that could do that?


Solution

  • I agree that crossJoin is the correct approach here. But afterwards it may be a bit more versatile to use a join instead of a union and groupBy, especially if there is more than one aggregation besides the count.

    from pyspark.sql import functions as F

    df = spark.createDataFrame(
        [('foo', 1),
         ('foo', 2),
         ('bar', 2),
         ('bar', 2)],
        ['x', 'y'])

    # All possible (x, y) combinations, including those absent from the data
    df_cartesian = df.select('x').distinct().crossJoin(df.select('y').distinct())
    df_cubed = df.cube('x', 'y').count()

    # A full outer join keeps the cube's subtotal rows (null join keys never
    # match) and adds the missing combinations with a null count, which
    # fillna then sets to 0
    df_cubed.join(df_cartesian, ['x', 'y'], 'full').fillna(0, ['count']).show()
    
    # +----+----+-----+
    # |   x|   y|count|
    # +----+----+-----+
    # |null|null|    4|
    # |null|   1|    1|
    # |null|   2|    3|
    # | bar|null|    2|
    # | bar|   1|    0|
    # | bar|   2|    2|
    # | foo|null|    2|
    # | foo|   1|    1|
    # | foo|   2|    1|
    # +----+----+-----+
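To illustrate the versatility point, here is a sketch of the same join approach with more than one aggregation: a count plus a sum over y. The `sum_y` column name and the local SparkSession setup are assumptions for this example, not part of the original answer.

    ```python
    from pyspark.sql import SparkSession, functions as F

    spark = SparkSession.builder.master('local[1]').getOrCreate()

    df = spark.createDataFrame(
        [('foo', 1),
         ('foo', 2),
         ('bar', 2),
         ('bar', 2)],
        ['x', 'y'])

    # All possible (x, y) combinations from the distinct values in the data
    df_cartesian = df.select('x').distinct().crossJoin(df.select('y').distinct())

    # Several aggregations at once instead of a bare count()
    # (sum_y is a hypothetical second metric for illustration)
    df_cubed = df.cube('x', 'y').agg(
        F.count('*').alias('count'),
        F.sum('y').alias('sum_y'))

    # Same full outer join as above; fill both aggregation columns with 0
    result = (df_cubed
              .join(df_cartesian, ['x', 'y'], 'full')
              .fillna(0, ['count', 'sum_y']))
    result.show()
    ```

With a union-and-groupBy approach, each extra aggregation would need its own neutral placeholder in the unioned rows; with the join, adding a metric is just one more column in `agg` and `fillna`.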