I'm trying to aggregate a dataset and determine if a condition is ever true for a row in the dataset.
Suppose I have a dataset with these values:
cust_id | travel_type | distance_travelled |
---|---|---|
1 | car | 10 |
1 | boat | 70 |
2 | car | 15 |
2 | plane | 600 |
3 | boat | 80 |
3 | plane | 100 |
I want to aggregate the dataset for each cust_id and determine if a certain condition was ever true for that id. For example, I want to determine if a customer ever took a flight with a distance greater than 500. The resulting dataset should look like this:
cust_id | hadLongDistanceFlight |
---|---|
1 | false |
2 | true |
3 | false |
The way I'm currently doing this is by collecting the results of a `when` expression into a set and later checking whether a `true` value exists in that set. While it works, the code becomes verbose and hard to maintain as more conditions are added. Is there a way of doing this cleanly?
My current code looks something like this:

```scala
myDataset
  .groupBy("cust_id")
  .agg(
    collect_set(
      when($"travel_type" === "plane" and $"distance_travelled" > 500, true)
    ).as("plane_set")
  )
  .withColumn("hadLongDistanceFlight", exists($"plane_set", _ === true))
```
I'm aware there is usually a `bool_or` function, but it doesn't seem to exist in the Scala Spark SQL API. I'm using version 3.3.2 of `org.apache.spark:spark-sql` and Scala 2.12.12.
It can be written as follows:
```scala
myDataset
  .groupBy("cust_id")
  .agg(
    max(
      when($"travel_type" === "plane" and $"distance_travelled" > 500, true)
        .otherwise(false)
    ).as("hadLongDistanceFlight")
  )
```

This works because `max` over booleans behaves as a logical OR: `false < true`, so the aggregate is `true` exactly when the condition held for at least one row in the group.