Tags: python, apache-spark, pyspark, histogram, rdd

Histogram of grouped data in PySpark


I have data consisting of a date-time, IDs, and velocity, and I'm hoping to get histogram data (start/end points and counts) of velocity for each ID using PySpark. Sample data:

df = spark.createDataFrame(
    [
        ("2023-06-01 07:09:17", "abc", 4.5),
        ("2023-06-01 07:09:18", "abc", 9.1),
        ("2023-06-01 07:09:19", "abc", 3.2),
        ("2023-06-01 07:10:06", "ddc", 5.1),
        ("2023-06-01 07:09:07", "ddc", 3.6),
        ("2023-06-01 07:09:08", "ddc", 2.6)
    ],
    ["date_time", "id", "velocity"]
)

I'm not too picky about how the output is formatted. Initially I was creating histograms using Spark's rdd.histogram(bins) function, but this was over all the velocity values (with no grouping). That code was:

df.filter(col("velocity").isNotNull()).rdd.map(lambda r: r["velocity"]).histogram(list(range(0, 100, 1)))

However, I cannot figure out how to do this for grouped data. I've tried these three things:

# This throws an error: 'GroupedData' object has no attribute 'rdd'
df.filter(col("velocity").isNotNull()).groupBy("id").rdd.histogram(list(range(0, 100, 1)))

# This throws a much longer error, but ends with: TypeError: 'str' object is not callable
# I think this has to do with the rdd.groupBy method
df.filter(col("velocity").isNotNull()).rdd.groupBy("id").histogram(list(range(0, 100, 1)))

# This throws a long error, with this TypeError: TypeError: '>' not supported between instances of 'tuple' and 'int'
df.filter(col("velocity").isNotNull()).select("id", "velocity").rdd.groupByKey().histogram(list(range(0, 100, 1)))

Solution

  • You could group by id and simply count the number of velocity values that fall into each of the intervals you are interested in. It would go as follows:

    from pyspark.sql.functions import col, sum, when

    result = df.filter(col("velocity").isNotNull())\
               .groupBy("id")\
               .agg(*[
                   sum(when((col("velocity") >= i) & (col("velocity") < i + 1), 1)
                       .otherwise(0)).alias(f"between_{i}_{i+1}")
                   for i in range(10)
               ])
    
    result.show()
    
    +---+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+------------+
    | id|between_0_1|between_1_2|between_2_3|between_3_4|between_4_5|between_5_6|between_6_7|between_7_8|between_8_9|between_9_10|
    +---+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+------------+
    |abc|          0|          0|          0|          1|          1|          0|          0|          0|          0|           1|
    |ddc|          0|          0|          1|          1|          0|          1|          0|          0|          0|           0|
    +---+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+------------+
    

    And if you prefer a simple array:

    from pyspark.sql.functions import array

    result\
        .select("id", array(
            [col(f"between_{i}_{i+1}") for i in range(10)]
        ).alias("histogram"))\
        .show(truncate=False)
    
    +---+------------------------------+
    |id |histogram                     |
    +---+------------------------------+
    |abc|[0, 0, 0, 1, 1, 0, 0, 0, 0, 1]|
    |ddc|[0, 0, 1, 1, 0, 1, 0, 0, 0, 0]|
    +---+------------------------------+