Tags: scala, apache-spark, apache-spark-sql

Determine if a condition is ever true in an aggregated dataset with the Scala Spark SQL library


I'm trying to aggregate a dataset and determine if a condition is ever true for a row in the dataset.

Suppose I have a dataset with these values:

cust_id   travel_type   distance_travelled
1         car           10
1         boat          70
2         car           15
2         plane         600
3         boat          80
3         plane         100
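
For reference, here is roughly how that sample data can be built (a minimal sketch, assuming a SparkSession in scope as spark):

import spark.implicits._

// sample data matching the table above
val myDataset = Seq(
  (1, "car", 10),
  (1, "boat", 70),
  (2, "car", 15),
  (2, "plane", 600),
  (3, "boat", 80),
  (3, "plane", 100)
).toDF("cust_id", "travel_type", "distance_travelled")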

I want to aggregate the dataset for each cust_id and determine if a certain condition was ever true for that id. For example, I want to determine if a customer ever took a flight with a distance greater than 500. The resulting dataset should look like this:

cust_id   hadLongDistanceFlight
1         false
2         true
3         false

The way I'm currently doing this is by collecting the results of a when evaluation into a set and later checking whether a true value exists in that set. While it works, it makes the code verbose and difficult to maintain as more conditions are added. Is there a way of doing this cleanly?

My current code looks something like this:

myDataset.groupBy("cust_id").agg(
  collect_set(
    when($"travel_type" === "plane" and $"distance_travelled" > 500, true)
  ).as("plane_set")
)
.withColumn("hadLongDistanceFlight", exists($"plane_set", _ === true))

I'm aware there is usually a bool_or function, but it doesn't seem to exist in the Scala Spark SQL library. I'm using version 3.3.2 of org.apache.spark:spark-sql with Scala 2.12.12.
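
(If I'm reading the built-in function docs correctly, bool_or does exist as a SQL aggregate in Spark 3.x, just not as a wrapper in the Scala functions object in 3.3.2, so it should be reachable through expr. I'd still prefer a solution in the Scala DSL.)

import org.apache.spark.sql.functions.expr

// the SQL bool_or aggregate, called through expr
myDataset
  .groupBy("cust_id")
  .agg(expr("bool_or(travel_type = 'plane' AND distance_travelled > 500)").as("hadLongDistanceFlight"))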


Solution

  • It can be written as follows:

    myDataset
      .groupBy("cust_id")
      .agg(
        // max over a boolean column is true if the condition holds for any row in the group
        max(when($"travel_type" === "plane" and $"distance_travelled" > 500, true).otherwise(false))
          .as("hadLongDistanceFlight")
      )