Tags: python, pandas, scala, apache-spark, pyspark

How can values in a Spark array column be efficiently replaced with values from a Pandas data frame?


I have a Spark data frame that contains a column of arrays with product ids from sold baskets.

import pandas as pd 
import pyspark.sql.types as T
from pyspark.sql import functions as F

df_baskets = spark.createDataFrame(
    [(1, ["546", "689", "946"]), (2, ["546", "799"] )],
    ("case_id","basket")
)

df_baskets.show()

#+-------+---------------+
#|case_id|         basket|
#+-------+---------------+
#|      1|[546, 689, 946]|
#|      2|     [546, 799]|
#+-------+---------------+

I would like to replace the product ids in each array with new ids given in a pandas data frame.


product_data = pd.DataFrame({
  "product_id": ["546", "689", "946", "799"],
  "new_product_id": ["S12", "S74", "S34", "S56"]
  })

product_data
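
#  product_id new_product_id
#0        546            S12
#1        689            S74
#2        946            S34
#3        799            S56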

I was able to replace the values by applying a simple Python function to the column that performs a lookup in the pandas data frame.


def get_new_id(product_id: str) -> str:
  try:
    row = product_data[product_data["product_id"] == product_id]
    return row["new_product_id"].item()
  except ValueError:
    # .item() raises ValueError when there is no unique match; keep the original id
    return product_id

apply_get = F.udf(lambda basket: [get_new_id(product) for product in basket], T.ArrayType(T.StringType()))

df_baskets = (
  df_baskets
    .withColumn('basket_renamed', apply_get(F.col('basket')))
)

df_baskets.show()

#+-------+---------------+---------------+
#|case_id|         basket| basket_renamed|
#+-------+---------------+---------------+
#|      1|[546, 689, 946]|[S12, S74, S34]|
#|      2|     [546, 799]|     [S12, S56]|
#+-------+---------------+---------------+

However, this approach has proven to be quite slow on data frames containing several tens of millions of cases. Is there a more efficient way to do this replacement (e.g. by using a different data structure than a pandas data frame, or a different method)?


Solution

  • You could explode your original data and join it to product_data (after converting it to a Spark frame):

    (
        df_baskets
        # one row per (case_id, product id)
        .withColumn("basket", F.explode(F.col("basket")))
        # join against the lookup table converted to a Spark frame
        .join(
            spark.createDataFrame(product_data)
            .withColumnRenamed("product_id", "basket")
            .withColumnRenamed("new_product_id", "basket_renamed"),
            on="basket"
        )
        # re-assemble the arrays per case
        .groupby("case_id")
        .agg(
            F.collect_list(F.col("basket")).alias("basket"),
            F.collect_list(F.col("basket_renamed")).alias("basket_renamed")
        )
    ).show()
    

    Output:

    +-------+---------------+---------------+
    |case_id|         basket| basket_renamed|
    +-------+---------------+---------------+
    |      1|[546, 689, 946]|[S12, S74, S34]|
    |      2|     [546, 799]|     [S12, S56]|
    +-------+---------------+---------------+
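
  • Since the lookup table is tiny compared to the baskets, it can help to mark it for a broadcast join so it is shipped to every executor instead of shuffled; a sketch, assuming the same frames as above:

    lookup_sdf = F.broadcast(
        spark.createDataFrame(product_data)
        .withColumnRenamed("product_id", "basket")
        .withColumnRenamed("new_product_id", "basket_renamed")
    )

    Note that an inner join drops any product id that is missing from product_data, and collect_list after a groupby does not guarantee the original order within each basket; if either matters, use a left join (with F.coalesce to fall back to the old id) and carry an F.posexplode position to restore the order.

  • If you prefer to keep the UDF, most of the cost in the original function comes from filtering the pandas frame once per element. Looking the ids up in a plain Python dictionary is typically much faster; a minimal sketch, assuming the same product_data:

    # build the lookup dict once on the driver; it is shipped with the UDF closure
    id_map = dict(zip(product_data["product_id"], product_data["new_product_id"]))

    apply_get = F.udf(
        lambda basket: [id_map.get(p, p) for p in basket],  # fall back to the original id
        T.ArrayType(T.StringType())
    )

    df_baskets = df_baskets.withColumn("basket_renamed", apply_get(F.col("basket")))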