
Dot product in PySpark?


I have:

df1

+------------------+----------+
|               var|multiplier|
+------------------+----------+
|              var1|         1|
|              var2|         2|
|              var3|         3|
+------------------+----------+

df2

+-------+----------+-----+-----+----------+---------+
|   varA|      varB| varC| var1|      var2|     var3|
+-------+----------+-----+-----+----------+---------+
|   abcd|       at1|    5|    1|        45|       12|
|   xyzw|       vt1|    7|    1|        23|       17|
+-------+----------+-----+-----+----------+---------+

Result: df3

+-------+----------+-----+-----+----------+---------+---------------+
|   varA|      varB| varC| var1|      var2|     var3|     sumproduct|
+-------+----------+-----+-----+----------+---------+---------------+
|   abcd|       at1|    5|    1|        90|       36|            127|
|   xyzw|       vt1|    7|    1|        46|       51|             98|
+-------+----------+-----+-----+----------+---------+---------------+
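In other words, each var column in df3 is the corresponding df2 value times its multiplier from df1, and sumproduct is their total, i.e. the dot product of (var1, var2, var3) with (1, 2, 3). For the first row:

```python
multipliers = [1, 2, 3]   # df1's multiplier column
row = [1, 45, 12]         # df2's (var1, var2, var3) for row "abcd"

products = [x * m for x, m in zip(row, multipliers)]
sumproduct = sum(products)

print(products)    # [1, 90, 36] -> df3's var1, var2, var3
print(sumproduct)  # 127         -> df3's sumproduct
```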

In pandas, I am able to achieve this with:

df1 = df1.set_index(['var'])
df3 = df2.dot(df1)

How can I do the same in PySpark?
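For reference, here is a self-contained version of that pandas approach on the sample data; it assumes the non-numeric key columns (varA, varB, varC) are excluded from the matrix product, since `DataFrame.dot` requires the left frame's columns to align with the right frame's index:

```python
import pandas as pd

df1 = pd.DataFrame({'var': ['var1', 'var2', 'var3'], 'multiplier': [1, 2, 3]})
df2 = pd.DataFrame({'varA': ['abcd', 'xyzw'], 'varB': ['at1', 'vt1'],
                    'varC': [5, 7], 'var1': [1, 1],
                    'var2': [45, 23], 'var3': [12, 17]})

df1 = df1.set_index(['var'])
# only the var columns participate in the dot product
df3 = df2[['var1', 'var2', 'var3']].dot(df1)
print(df3['multiplier'].tolist())  # [127, 98]
```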


Solution

    from pyspark.sql import functions as F

    # put the multipliers into a Python list
    lst = df1.select("multiplier").rdd.flatMap(lambda x: x).collect()

    df3 = (
        df2.withColumn('a1', F.array('var1', 'var2', 'var3'))            # array of df2's var columns
        .withColumn('a2', F.array([F.lit(x) for x in lst]))              # literal array of df1's multipliers
        .withColumn('a1', F.expr("transform(a1, (x, i) -> a2[i] * x)"))  # elementwise products
        .select('varA', 'varB', 'varC', 'a1',
                *[F.col('a1')[i].alias(f'var{i + 1}') for i in range(3)])  # expand a1 back into var columns
        .select('*', F.expr("aggregate(a1, cast(0 as bigint), (x, i) -> x + i)").alias('sumproduct'))
        .drop('a1', 'a2')
    )
    

    df3.show()

    +----+----+----+----+----+----+----------+
    |varA|varB|varC|var1|var2|var3|sumproduct|
    +----+----+----+----+----+----+----------+
    |abcd| at1|   5|   1|  90|  36|       127|
    |xyzw| vt1|   7|   1|  46|  51|        98|
    +----+----+----+----+----+----+----------+
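    The two higher-order SQL functions doing the work, `transform` and `aggregate`, behave like this plain-Python sketch of the first row:

```python
a1 = [1, 45, 12]   # df2's var columns for row "abcd"
a2 = [1, 2, 3]     # the multiplier literals from df1

# transform(a1, (x, i) -> a2[i] * x): elementwise product
a1 = [a2[i] * x for i, x in enumerate(a1)]

# aggregate(a1, 0, (acc, x) -> acc + x): fold the products into a sum
sumproduct = 0
for x in a1:
    sumproduct += x

print(a1, sumproduct)  # [1, 90, 36] 127
```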
    

    Remember, if all you need is the dot product itself, a udf is also a possibility. NumPy is very good at this kind of computation:

    import numpy as np
    from pyspark.sql import functions as F
    from pyspark.sql.types import IntegerType

    lst = df1.select("multiplier").rdd.flatMap(lambda x: x).collect()
    dot_array = F.udf(lambda x, y: int(np.dot(x, y)), IntegerType())  # NumPy dot product per row
    df2.withColumn("dotproduct",
                   dot_array(F.array('var1', 'var2', 'var3'), F.array([F.lit(x) for x in lst]))).show()
    
    +----+----+----+----+----+----+----------+
    |varA|varB|varC|var1|var2|var3|dotproduct|
    +----+----+----+----+----+----+----------+
    |abcd| at1|   5|   1|  45|  12|       127|
    |xyzw| vt1|   7|   1|  23|  17|        98|
    +----+----+----+----+----+----+----------+
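    Per row, the udf simply computes a NumPy dot product of the var values against the multipliers; with the sample data:

```python
import numpy as np

multipliers = [1, 2, 3]
rows = [[1, 45, 12], [1, 23, 17]]  # df2's (var1, var2, var3) per row

dots = [int(np.dot(r, multipliers)) for r in rows]
print(dots)  # [127, 98]
```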