I have a Spark data frame that contains a column of arrays with product ids from sold baskets.
import pandas as pd
import pyspark.sql.types as T
from pyspark.sql import functions as F
df_baskets = spark.createDataFrame(
    [(1, ["546", "689", "946"]), (2, ["546", "799"])],
    ("case_id", "basket"),
)
df_baskets.show()
#+-------+---------------+
#|case_id| basket|
#+-------+---------------+
#| 1|[546, 689, 946]|
#| 2| [546, 799]|
#+-------+---------------+
I would like to replace the product ids in each array with new ids given in a pandas data frame.
product_data = pd.DataFrame({
    "product_id": ["546", "689", "946", "799"],
    "new_product_id": ["S12", "S74", "S34", "S56"],
})
product_data
I was able to replace the values by applying to the column a simple Python function that performs a lookup in the pandas data frame.
def get_new_id(product_id: str) -> str:
    try:
        row = product_data[product_data["product_id"] == product_id]
        return row["new_product_id"].item()
    except ValueError:
        # no (or multiple) matches: keep the original id
        return product_id
apply_get = F.udf(
    lambda basket: [get_new_id(product) for product in basket],
    T.ArrayType(T.StringType()),
)
df_baskets = (
df_baskets
.withColumn('basket_renamed', apply_get(F.col('basket')))
)
df_baskets.show()
#+-------+---------------+---------------+
#|case_id| basket| basket_renamed|
#+-------+---------------+---------------+
#| 1|[546, 689, 946]|[S12, S74, S34]|
#| 2| [546, 799]| [S12, S56]|
#+-------+---------------+---------------+
However, this approach has proven to be quite slow on data frames containing several tens of millions of cases. Is there a more efficient way to do this replacement (e.g., by using a different data structure than a pandas data frame, or a different method)?
You could explode your original data frame and join it to product_data (after converting it to a Spark data frame):
(
    df_baskets
    .withColumn("basket", F.explode(F.col("basket")))
    .join(
        spark.createDataFrame(product_data)
            .withColumnRenamed("product_id", "basket")
            .withColumnRenamed("new_product_id", "basket_renamed"),
        on="basket",
    )
    .groupby("case_id")
    .agg(
        F.collect_list(F.col("basket")).alias("basket"),
        F.collect_list(F.col("basket_renamed")).alias("basket_renamed"),
    )
).show()
Output:
+-------+---------------+---------------+
|case_id|         basket| basket_renamed|
+-------+---------------+---------------+
|      1|[546, 689, 946]|[S12, S74, S34]|
|      2|     [546, 799]|     [S12, S56]|
+-------+---------------+---------------+