Tags: python, pyspark, databricks, etl, iot

How can I compute a dot product between two lists in a Spark DataFrame without using a UDF?


I have a set of IoT data that is transformed by Azure Databricks in a Python job. The cluster runs Databricks Runtime 13.3 LTS (Apache Spark 3.4.1, Scala 2.12) on Standard_DS3_v2 nodes.

I write the messages to a Delta table.

I then read the data back from the Delta table and compare it with itself within a 2-hour window inside a for loop, so I end up performing a very large number of comparisons (n*n).

The UDF approach works, but it is slow: it takes 15 minutes for 1,000 messages, and I need to stay below 10 minutes for a volume of 1.5 million messages.

This is what I did with the UDF:

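from pyspark.sql.functions import sqrt, udf
from pyspark.sql.types import DoubleType

for row in rows.collect():  # rows is a DataFrame, so bring its rows to the driver to loop over them
    # Mean squared error over the first 450 elements (the last element is excluded); returns a double, not the default string
    mse = udf(lambda x: sum((a - b) * (a - b) for a, b in zip(x[:-1], row["enc"][:-1])) / 450, DoubleType())
    df_compare = df_compare.withColumn('diff_enc', sqrt(mse(df_compare.enc)))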

where rows is a DataFrame with the same schema as df_compare, and both row["enc"] and df_compare.enc contain a list of 451 elements in each cell:

Column enc
[1.0,2.0,3.0,4.0,...]
[1.0,2.0,3.0,4.0,...]
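For reference, the value the UDF (together with the sqrt around it) computes for each pair is the root mean squared error over the first 450 of the 451 elements, so it is built from squared differences rather than a plain dot product. A minimal plain-Python sketch of that per-pair formula, handy for spot-checking any Spark rewrite against a few rows; the function name diff_enc_reference is hypothetical:

```python
import math
from typing import Sequence


def diff_enc_reference(x: Sequence[float], ref: Sequence[float]) -> float:
    """Hypothetical helper mirroring the UDF: RMSE over all elements except the last.

    For the question's 451-element arrays, len(x) - 1 == 450, the divisor used in the UDF.
    """
    xs, rs = x[:-1], ref[:-1]  # drop the last element, as x[:-1] does in the UDF
    if len(xs) != len(rs) or not xs:
        raise ValueError("arrays must have the same length and at least two elements")
    sse = sum((a - b) * (a - b) for a, b in zip(xs, rs))  # sum of squared differences
    return math.sqrt(sse / len(xs))  # mean, then square root


print(diff_enc_reference([1.0, 2.0, 3.0, 100.0], [1.0, 2.0, 5.0, 200.0]))
```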

Is there a smarter way to make this computation faster using Spark?

Would it be a better idea to stop using Databricks, use a NoSQL database instead, and do the computation in a function?


Solution

  • If you can upgrade to Spark 3.5.0, you can use reduce, which lets you express the mean squared error calculation with pure Spark functions and no UDF:

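    # reduce() requires PySpark 3.5.0 or later; note that reduce and pow shadow the Python builtins
    from pyspark.sql.functions import array_size, arrays_zip, col, lit, pow, reduce

    df = spark.sql("select array(1,2,2,3,3,4) as x, array(1,2,3,4,5,6) as enc")
    df.withColumn("mse", reduce(arrays_zip(col("x"), col("enc")),
                         lit(0.0),
                         lambda acc, e: acc + pow(e["x"] - e["enc"], 2) / array_size(col("x")))
                 ).show()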
    +------------------+------------------+------------------+
    |                 x|               enc|               mse|
    +------------------+------------------+------------------+
    |[1, 2, 2, 3, 3, 4]|[1, 2, 3, 4, 5, 6]|1.6666666666666665|
    +------------------+------------------+------------------+