Tags: dataframe, pyspark, multiplication

How to multiply all the columns of a DataFrame in PySpark by another single column


I have monthly population data (Jan through Dec) for a particular year, plus a column called "Constant". I need to multiply each month column, Jan to Dec, by the Constant column in Spark. For example, given the following data:

        JAN  FEB  MAR ... DEC  Constant
City1   160  158  253     391        12
City2   212   27  362     512        34
City3    90  150  145     274        56

After the multiplication, I want a new (or replaced) DataFrame with these values:

        JAN   FEB   MAR   ... DEC
City1   1920  1896  3036      4692
City2   7208   918  12308     17408
City3   5040  8400  8120      15344

I am able to do it one column at a time with this code:

from pyspark.sql.functions import col

df.select("JAN", "Constant").withColumn("JAN", col("JAN") * col("Constant")).show()

Is there a function or loop that can multiply all the month columns at once and produce the new DataFrame values for every month?


Solution

  • You could express your logic using a struct of columns. A struct is essentially a higher-order column that groups other columns, so we can multiply each month by the constant, alias each product with its month name, and then expand the struct back out using columnname.*. This way you don't have to call withColumn 12 times. You can put all your months in listofmonths.

    df.show()  # sample data
    #+-----+---+---+---+---+--------+
    #| City|JAN|FEB|MAR|DEC|Constant|
    #+-----+---+---+---+---+--------+
    #|City1|160|158|253|391|      12|
    #|City2|212| 27|362|512|      34|
    #|City3| 90|150|145|274|      56|
    #+-----+---+---+---+---+--------+
    
    
    
    listofmonths = ['JAN', 'FEB', 'MAR', 'DEC']
    
    from pyspark.sql import functions as F
    df.withColumn("arr", F.struct(*[(F.col(x)*F.col('Constant')).alias(x) for x in listofmonths]))\
      .select("City","arr.*")\
      .show()
    
    #+-----+----+----+-----+-----+
    #| City| JAN| FEB|  MAR|  DEC|
    #+-----+----+----+-----+-----+
    #|City1|1920|1896| 3036| 4692|
    #|City2|7208| 918|12308|17408|
    #|City3|5040|8400| 8120|15344|
    #+-----+----+----+-----+-----+
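
    The select("City", "arr.*") on the last line expands the struct fields back into top-level columns, so the result keeps the original month names and drops Constant without needing an explicit drop.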
    

    You could also just use df.columns instead of listofmonths like this:

    from pyspark.sql import functions as F
    df.withColumn("arr", F.struct(*[(F.col(x)*F.col('Constant')).alias(x) for x in df.columns if x not in ('City', 'Constant')]))\
      .select("City","arr.*")\
      .show()
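
    If you don't need the intermediate struct at all, a plain select with the same list comprehension should give an equivalent result. A minimal sketch, assuming the same df and listofmonths as above:

    from pyspark.sql import functions as F

    # Multiply every month column by Constant; the alias keeps the original names.
    df.select("City", *[(F.col(x) * F.col("Constant")).alias(x) for x in listofmonths])\
      .show()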