scala, performance, apache-spark, apache-spark-sql, processing-efficiency

Spark Scala: Querying same table multiple times


I am trying to query multiple columns from the same table (bigTable) to generate some aggregated columns (column1_sum, column2_sum, column3_count). In the end, I join all the columns together to form one table.

Code below:

val t1 = bigTable
            .filter($"column10" === value1)
            .groupBy("key1","key2")
            .agg(sum("column1") as "column1_sum")

val t2 = bigTable
            .filter($"column11"===1)
            .filter($"column10" === value1)
            .groupBy("key1","key2")
            .agg(sum("column2") as "column2_sum")

val t3 = bigTable
            .filter($"column10" === value3)
            .groupBy("key1","key2")
            .agg(countDistinct("column3") as "column3_count")

tAll
            .join(t1,Seq("key1","key2"),"left_outer")
            .join(t2,Seq("key1","key2"),"left_outer")
            .join(t3,Seq("key1","key2"),"left_outer")

Issues with the above code

bigTable is huge (it runs into millions of rows), so scanning it multiple times is inefficient and the query takes a long time to run.

Any ideas on how I could achieve the same output more efficiently? Is there a way to query bigTable fewer times?

Thanks a lot in advance.


Solution

  • The simplest improvement is to perform only a single aggregation, with the predicates pushed into the CASE ... WHEN ... blocks, and to replace countDistinct with an approximate equivalent:

    tAll
      .groupBy("key1","key2")
      .agg(
        sum(
          when($"column10" === "value1", $"column1")
        ).as("column1_sum"),
        sum(
          when($"column10" === "value1" and $"column11" === 1, $"column2")
        ).as("column2_sum"),
        approx_count_distinct(
          when($"column10" === "value3", $"column3")
        ).as("column3_count"))
      .join(tAll, Seq("key1", "key2"), "right_outer")
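
    Note that when(condition, value) without an otherwise evaluates to null for rows where the condition is false, and both sum and approx_count_distinct skip nulls, so groups with no matching rows end up with null aggregates, mirroring the left_outer joins in the original code. approx_count_distinct is based on HyperLogLog++ and also takes an optional maximum relative error argument (rsd) if you need to tune its precision.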
    

    Depending on the functions used and a priori knowledge about the data distribution, you can also try to replace the aggregation with window functions using similar CASE ... WHEN ... logic:

    import org.apache.spark.sql.expressions.Window
    
    val w = Window
      .partitionBy("key1", "key2")
      .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
    
    tAll
      .withColumn(
        "column1_sum", 
        sum(when($"column10" === "value1", $"column1")).over(w))
     ...
    

    but it is often a less stable approach.
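
    If it helps to see the shape of the complete window version, here is a minimal sketch under the same hypothetical schema as the question (column1, column2, column3, column10, column11); approx_count_distinct is used because distinct aggregates such as countDistinct are not supported as window functions:

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.{approx_count_distinct, sum, when}

    val w = Window
      .partitionBy("key1", "key2")
      .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

    // Each aggregate is computed once per (key1, key2) group and attached
    // to every row of that group, so no join back to tAll is needed.
    val tAllWithAggs = tAll
      .withColumn("column1_sum",
        sum(when($"column10" === "value1", $"column1")).over(w))
      .withColumn("column2_sum",
        sum(when($"column10" === "value1" && $"column11" === 1, $"column2")).over(w))
      .withColumn("column3_count",
        approx_count_distinct(when($"column10" === "value3", $"column3")).over(w))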

    You should also consider bucketing bigTable by the grouping columns:

    val n: Int = ???  // Number of buckets
    bigTable.write.bucketBy(n, "key1", "key2").saveAsTable("big_table_clustered")
    
    val bigTableClustered = spark.table("big_table_clustered")
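
    Once the bucketed copy exists, one possible usage sketch (again assuming the hypothetical columns above) is to run the single aggregation against it; because the bucket columns match the grouping keys, Spark can usually satisfy the groupBy without an extra shuffle:

    import org.apache.spark.sql.functions.{approx_count_distinct, sum, when}

    // The groupBy keys match the bucket columns, so Spark can typically
    // skip the shuffle (Exchange) step for this aggregation.
    val aggs = bigTableClustered
      .groupBy("key1", "key2")
      .agg(
        sum(when($"column10" === "value1", $"column1")).as("column1_sum"),
        sum(when($"column10" === "value1" && $"column11" === 1, $"column2")).as("column2_sum"),
        approx_count_distinct(when($"column10" === "value3", $"column3")).as("column3_count"))

    val result = aggs.join(tAll, Seq("key1", "key2"), "right_outer")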