I am trying to query multiple columns from the same table (bigTable) to generate some aggregated columns (column1_sum, column2_sum, column3_count). In the end, I join all the aggregated results together to form one table.
Code below:
val t1 = bigTable
  .filter($"column10" === value1)
  .groupBy("key1", "key2")
  .agg(sum("column1") as "column1_sum")

val t2 = bigTable
  .filter($"column11" === 1)
  .filter($"column10" === value1)
  .groupBy("key1", "key2")
  .agg(sum("column2") as "column2_sum")

val t3 = bigTable
  .filter($"column10" === value3)
  .groupBy("key1", "key2")
  .agg(countDistinct("column3") as "column3_count")

tAll
  .join(t1, Seq("key1", "key2"), "left_outer")
  .join(t2, Seq("key1", "key2"), "left_outer")
  .join(t3, Seq("key1", "key2"), "left_outer")
Issues with the above code:
bigTable is huge (millions of rows), so scanning it multiple times is inefficient and the query takes a long time to run.
Any ideas on how I could achieve the same output in a more efficient way? Is there a way to query bigTable fewer times?
Thanks a lot in advance.
The simplest improvement is to perform only a single aggregation, where the predicates are pushed into CASE ... WHEN ... blocks, and to replace countDistinct with an approximate equivalent:
import org.apache.spark.sql.functions._

tAll
  .groupBy("key1", "key2")
  .agg(
    sum(
      when($"column10" === value1, $"column1")
    ).as("column1_sum"),
    sum(
      when($"column10" === value1 and $"column11" === 1, $"column2")
    ).as("column2_sum"),
    approx_count_distinct(
      when($"column10" === value3, $"column3")
    ).as("column3_count"))
  .join(tAll, Seq("key1", "key2"), "right_outer")
Depending on the functions used and a priori knowledge about the data distribution, you can also try to replace the aggregation with window functions using similar CASE ... WHEN ... logic:
import org.apache.spark.sql.expressions.Window

val w = Window
  .partitionBy("key1", "key2")
  .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

tAll
  .withColumn(
    "column1_sum",
    sum(when($"column10" === value1, $"column1")).over(w))
  ...
but it is often a less stable approach.
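For reference, a fuller version of that window-based variant could look like the sketch below (an assumption-laden sketch: it uses size(collect_set(...)) as a stand-in for the distinct count, since distinct aggregates are generally not usable over windows, and dropDuplicates to keep one row per key):

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val w = Window
  .partitionBy("key1", "key2")
  .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

tAll
  .withColumn("column1_sum",
    sum(when($"column10" === value1, $"column1")).over(w))
  .withColumn("column2_sum",
    sum(when($"column10" === value1 and $"column11" === 1, $"column2")).over(w))
  // collect_set skips nulls, so only rows matching the WHEN condition are counted
  .withColumn("column3_count",
    size(collect_set(when($"column10" === value3, $"column3")).over(w)))
  .dropDuplicates("key1", "key2")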
You should also consider bucketing bigTable using the grouping columns:
val n: Int = ??? // Number of buckets
bigTable.write.bucketBy(n, "key1", "key2").saveAsTable("big_table_clustered")
val bigTableClustered = spark.table("big_table_clustered")
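Once the data is written that way, aggregations keyed on the bucketing columns can avoid the shuffle. As a sketch (assuming the same columns are present in big_table_clustered), the single-pass aggregation from above could then read the bucketed copy:

// groupBy on the bucketing keys can reuse the existing bucketing and skip the exchange step
bigTableClustered
  .groupBy("key1", "key2")
  .agg(
    sum(when($"column10" === value1, $"column1")).as("column1_sum"),
    sum(when($"column10" === value1 and $"column11" === 1, $"column2")).as("column2_sum"),
    approx_count_distinct(when($"column10" === value3, $"column3")).as("column3_count"))
  .join(tAll, Seq("key1", "key2"), "right_outer")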