amazon-web-services, scala, apache-spark

How to aggregate columns dynamically in Spark Scala?


I have recently started working with Spark and Scala. I have a requirement where I need to find the sums of a few columns and pick one of them within a CASE statement. I've written the corresponding Spark SQL code, but I'm unable to implement the same logic dynamically in Spark Scala. Below is what I'm trying to achieve -

SQL Code-

Select  col_A,
        round(case when sum(amt_M)   <> 0.0 then sum(amt_M) 
                   when sum(amt_N)   <> 0.0 then sum(amt_N)
                   when sum(amt_P)   <> 0.0 then sum(amt_P) 
              end,1) as pct 
from table_T1
group by col_A

The use case is to read the column names from a variable and implement the CASE-statement logic above dynamically. Currently there are 3 columns, but that number could increase later on.
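
To make the requirement concrete, here is a minimal sketch of what I mean by "dynamic" (the extra column name amt_Q is only illustrative): the column names arrive in a variable, and the CASE should simply gain one WHEN branch per entry in the list.

// The columns to check come from a comma-separated variable and may grow over time
val cols = "amt_M,amt_N,amt_P"            // e.g. later "amt_M,amt_N,amt_P,amt_Q"
val amtCols = cols.split(",").toSeq

// Desired CASE, built from the list:
//   case when sum(amt_M) <> 0.0 then sum(amt_M)
//        when sum(amt_N) <> 0.0 then sum(amt_N)
//        when sum(amt_P) <> 0.0 then sum(amt_P) end
val caseBranches = amtCols.map(c => s"when sum($c) <> 0.0 then sum($c)")
val caseExpr = caseBranches.mkString("case ", " ", " end")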

Below is the code I tried in Spark Scala -

import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import scala.collection._

val df = spark.table("database.table_T1")

val cols = "amt_M,amt_N,amt_P"

val aggCols = cols.split(",").toSeq

val sums = aggCols.map(colName => when(round(sum(colName).cast(DoubleType), 1) =!= 0.0, sum(colName).cast(DoubleType).alias("sum_" + colName)))

val df2 = df.groupBy(col("col_A")).agg(sums.head, sums.tail:_*)

However, this is not giving the desired results: each element of sums becomes its own WHEN expression, so the aggregation returns a separate column per amount instead of a single chained CASE. Please help me on this.

Input Data

+--------+--------------------+---------------------+----------------------+
|col_A   |amt_M               |amt_N                |amt_P                 |
+--------+--------------------+---------------------+----------------------+
|5C-SVS-1|0.0                 |0.04064912622009295  |1.6256888829356116E-4 |
|5C-SVS-1|0.0                 |0.026542159153759487 |8.574900251977566E-4  |
|5C-SVS-1|0.0                 |5.703894148377958E-5 |1.0745888408402782E-7 |
|5C-SVS-1|0.0                 |0.0                  |4.514561031069833E-4  |
|5C-SVS-1|0.0                 |0.011794053124022862 |0.0020388259536434656 |
|5C-SVS-1|0.0                 |7.55793849084569E-4  |0.0017105736019335327 |
|5C-SVS-1|0.0                 |0.019303776946698548 |2.240625765755109E-5  |
|5C-SVS-1|0.0                 |-8.028117213883126E-6|-2.1979360825171534E-6|
|5C-SVS-1|0.001940948839163001|0.029163686986129422 |0.09505621692309557   |
|5C-SVS-1|0.0                 |2.515835289984397E-7 |1.1486227577926157E-8 |
|5C-SVS-1|0.0                 |0.007844299114837874 |9.974187712854785E-4  |
|5C-SVS-1|0.0                 |5.033123682586349E-4 |1.3644443189731007E-4 |
|5C-SVS-1|0.0                 |0.026331681277001386 |6.022434166108063E-4  |
|5C-SVS-1|0.0                 |8.098023638080503E-6 |1.0                   |
|5C-SVS-1|0.0                 |0.03655893437209876  |0.003113370686486882  |
|5C-SVS-1|0.0                 |0.01409363925733864  |6.239415097038338E-4  |
|5C-SVS-1|0.0                 |0.02171856350557304  |0.0                   |
|5C-SVS-1|0.008435341548288601|0.03347191686227869  |0.35221710556006247   |
|5C-SVS-1|0.0                 |-2.547132732700875E-6|-0.13073525789233997  |
|5C-SVS-1|0.006057441518729214|0.024036273783621134 |0.21447606070652467   |
+--------+--------------------+---------------------+----------------------+

Expected Output

+--------+---+
|   col_A|pct|
+--------+---+
|5C-SVS-1|1.0|
+--------+---+

Thanks


Solution

  • I solved the requirement by implementing the method below -

    import org.apache.spark.sql.Column
    import org.apache.spark.sql.functions.{col, round, sum, when}
    import org.apache.spark.sql.types._

    def getSumCols(columnList: List[String]): Column = {

        // Start the CASE chain with the condition for the 1st column
        var conditionColumn: Column = when(sum(col(columnList(0)).cast(DoubleType)) =!= 0.0, sum(col(columnList(0)).cast(DoubleType)))

        // Iterate from the 2nd column to the last, appending each WHEN branch to the chain started above
        for (c <- 1 to columnList.length - 1) {
            conditionColumn = conditionColumn.when(sum(col(columnList(c)).cast(DoubleType)) =!= 0.0, sum(col(columnList(c)).cast(DoubleType)))
        }
        round(conditionColumn, 1)
    }
    

    This is then called during the aggregation as below -

    val cols = "amt_M,amt_N,amt_P"
    
    val colList = cols.split(",").toList
    
    val conditionColumn: Column = getSumCols(colList)
    
    val df1 = df.groupBy("col_A").agg(conditionColumn.alias("pct"))
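
    The same CASE chain can also be built without a var and an index-based loop. A minimal foldLeft sketch, equivalent to getSumCols and assuming the same imports, would be -

    import org.apache.spark.sql.Column
    import org.apache.spark.sql.functions.{col, round, sum, when}
    import org.apache.spark.sql.types.DoubleType

    // Equivalent sketch: seed the CASE with the first column, then fold in the rest
    def getSumColsFold(columnList: List[String]): Column = {
        def sumOf(name: String): Column = sum(col(name).cast(DoubleType))

        val first = when(sumOf(columnList.head) =!= 0.0, sumOf(columnList.head))
        val chained = columnList.tail.foldLeft(first) { (acc, c) =>
            acc.when(sumOf(c) =!= 0.0, sumOf(c))   // one WHEN branch per remaining column
        }
        round(chained, 1)
    }

    Either version generates a single CASE WHEN sum(amt_M) ... WHEN sum(amt_N) ... WHEN sum(amt_P) ... END aggregate, so new amount columns only require extending the column list. Note that there is no otherwise branch, so pct comes out as null for a group where every sum equals 0.0, which matches the original SQL.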