Tags: scala, apache-spark, apache-spark-sql

Sum column based on another column's values


I have a DataFrame which represents a grayscale image:

+---+---+-----+
|  w|  h|color|
+---+---+-----+
|  0|  0|255.0|
|  0|  1|255.0|
|  0|  2|255.0|
|  0|  3|255.0|
|  0|  4|255.0|
|  0|  5|255.0|
|  0|  6|255.0|
|  0|  7|255.0|
|  0|  8|255.0|
|  0|  9|255.0|
|  0| 10|255.0|
|  0| 11|255.0|
|  1|  0|255.0|
|  1|  1|255.0|
|  1|  2|255.0|
|  1|  3|255.0|
|  1|  4|255.0|
|  1|  5|255.0|
|  1|  6|255.0|
|  1|  7|255.0|
+---+---+-----+
only showing top 20 rows

For each row, I need to sum "color" over all rows whose "w" and "h" values lie in the range from the current row's values to those values plus some number.

For a better understanding, a possible solution would look like this (pseudocode, not valid Spark API):

val windowW = Window.rangeBetween(Window.currentRow, Window.currentRow + num1)
val windowH = Window.rangeBetween(Window.currentRow, Window.currentRow + num2)

df.withColumn("color_sum", sum(col("color")).over(col("w").windowW and col("h").windowH))

Could you please give me some hints on how to achieve this calculation?

Expected output for the very first row:

+---+---+-----+----------+
|  w|  h|color|sum(color)|
+---+---+-----+----------+
|  0|  0|255.0|      1020|
+---+---+-----+----------+

Where num1 and num2 both equal 1.

That means the sum is taken over the rows (4 × 255.0 = 1020):

(0, 0), (0, 1), (1, 0), (1, 1)

For row (1, 1), the sum would be taken over rows (1, 1), (1, 2), (2, 1), and (2, 2).
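
For reference, the sample frame can be reproduced like this (a minimal sketch; the 12×12 all-white image is an assumption based on the rows shown, and spark is the usual spark-shell session):

// assumes a running SparkSession named spark, as in spark-shell
import spark.implicits._

// hypothetical reproduction of the data: a 12x12 all-white image
val df = (for (w <- 0 until 12; h <- 0 until 12) yield (w, h, 255.0))
  .toDF("w", "h", "color")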


Solution

  • You can't use window functions to apply a function over this kind of 2-dimensional range: a window frame is ordered along a single dimension, while here the frame would have to span two columns at once.

    You can instead use a self-join to find the neighboring rows:

    df.as("df1")
      .join(df.as("df2"),
        ((col("df1.w") - col("df2.w") <= 0) && col("df1.w") - col("df2.w") >= -1) &&
          ((col("df1.h") - col("df2.h") <= 0) && col("df1.h") - col("df2.h") >= -1),
        "inner"
      )
      .groupBy("df1.w", "df1.h")
      .agg(min("df1.color") as "color", sum("df2.color") as "sum")
      .orderBy("w", "h")
      .show()
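
    To make the window sizes configurable, the same condition can be expressed with between. This is a sketch under the question's assumptions, with num1 and num2 as the neighborhood extents:

    import org.apache.spark.sql.functions.{col, min, sum}

    val num1 = 1 // extent along w
    val num2 = 1 // extent along h

    df.as("df1")
      .join(df.as("df2"),
        // keep df2 rows with df1.w <= df2.w <= df1.w + num1
        // and df1.h <= df2.h <= df1.h + num2
        col("df2.w").between(col("df1.w"), col("df1.w") + num1) &&
          col("df2.h").between(col("df1.h"), col("df1.h") + num2),
        "inner"
      )
      .groupBy("df1.w", "df1.h")
      .agg(min("df1.color") as "color", sum("df2.color") as "sum")
      .orderBy("w", "h")
      .show()

    Note that a self-join on inequalities like this is typically executed as a broadcast nested-loop or cartesian join, so it can get expensive on large images.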