Tags: dataframe, scala, apache-spark, pyspark, apache-spark-sql

Aggregation on a set of columns in a DataFrame using Spark and Scala (get the max non-null element of each column using selectExpr)


I have a DataFrame as follows:

val df = Seq(
  ("GUID1", Some(1), Some(22), Some(30), Some(56)),
  ("GUID1", Some(4), None,     Some(35), Some(52)),
  ("GUID1", None,    Some(24), None,     Some(58)),
  ("GUID2", Some(5), Some(21), Some(31), None)
).toDF("GUID", "A", "B", "C", "D")
df.show
+-----+----+----+----+----+
| GUID|   A|   B|   C|   D|
+-----+----+----+----+----+
|GUID1|   1|  22|  30|  56|
|GUID1|   4|null|  35|  52|
|GUID1|null|  24|null|  58|
|GUID2|   5|  21|  31|null|
+-----+----+----+----+----+

This is a simplified DataFrame (in reality there are more than 30 columns).

The goal is to aggregate per GUID so that a min, max, or some custom value is derived for each set of columns. For example, I want the max non-null value of columns A and B and the min of C and D, using the arrays below.

val max_cols = Array("A", "B")
val min_cols = Array("C", "D")

val df1 = df.groupBy("GUID").agg(collect_list(struct(max_cols.head, max_cols: _*))).as("values")
            .selectExpr("GUID", "array_max(filter(values, x-> x.c.isNotNull))[c] for (c <- values)")

This line does not work.

Expected output is:

+-----+---+---+---+----+
| GUID|  A|  B|  C|   D|
+-----+---+---+---+----+
|GUID1|  4| 24| 30|  52|
|GUID2|  5| 21| 31|null|
+-----+---+---+---+----+

I found a similar question for PySpark (pyspark get latest non-null element of every column in one row), but I was not able to get it working in Scala.

Any idea how to solve it?


Solution

  • You can create the min/max expressions and pass them to agg. By default, min and max ignore nulls; they return null only when every value in the group is null.

    import org.apache.spark.sql.functions.{max, min}

    // max_cols first so the output columns match the expected A, B, C, D order
    val exprs = max_cols.map(c => max(c).as(c)) ++ min_cols.map(c => min(c).as(c))

    df.groupBy("GUID").agg(exprs.head, exprs.tail: _*).show