
Conditional mapping in Pyspark


I have a PySpark DataFrame called inventory with 2M rows and the following columns:

category_id  sub_category_id  product_code  product_name
1001         A001             X123          Gadget A
1001         A002             X456          Gadget B
2002         B003             Y123          Gadget C
3003         C000             Z123          Gadget D
3003         C002             Z456          Gadget E
3003         C003             Z789          Gadget F

I want to remap sub_category_id using a separate mapping for each category_id:

  1. If category_id is 1001, map:

    • sub_category_id A001 to M001
    • sub_category_id A002 to M002
    • If sub_category_id is not in the mapping, leave it unchanged.
  2. If category_id is 2002, map:

    • sub_category_id B003 to N003 ...

Here is the mapping configuration:


mappings = [
    {
        "conditions": [{"column": "category_id", "values": ["1001"]}],
        "values_mapping": {
            "A001": "M001",
            "A002": "M002"
        }
    },
    {
        "conditions": [{"column": "category_id", "values": ["2002"]}],
        "values_mapping": {
            "B003": "N003"
        }
    },
    {
        "conditions": [{"column": "category_id", "values": ["3003"]}],
        "values_mapping": {
            "C001": "P001",
            "C002": "P002",
            "C003": "P003"
        }
    }
]
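Because every rule in this configuration is an equality test on category_id, the whole configuration can be flattened into plain lookup rows of (category_id, old value, new value). A minimal sketch, assuming each rule has exactly one condition and that it is on category_id; build_lookup_rows is a hypothetical helper name.

```python
# The configuration from the question, in compact form
mappings = [
    {"conditions": [{"column": "category_id", "values": ["1001"]}],
     "values_mapping": {"A001": "M001", "A002": "M002"}},
    {"conditions": [{"column": "category_id", "values": ["2002"]}],
     "values_mapping": {"B003": "N003"}},
    {"conditions": [{"column": "category_id", "values": ["3003"]}],
     "values_mapping": {"C001": "P001", "C002": "P002", "C003": "P003"}},
]


def build_lookup_rows(mappings):
    """Flatten rules into (category_id, old_sub_category_id, new_sub_category_id) rows.

    Assumes each rule has exactly one condition, on the category_id column.
    """
    rows = []
    for rule in mappings:
        # Unpacking raises ValueError if a rule has zero or several conditions
        (condition,) = rule["conditions"]
        if condition["column"] != "category_id":
            raise ValueError(f"unsupported condition column: {condition['column']}")
        for category in condition["values"]:
            for old, new in rule["values_mapping"].items():
                rows.append((category, old, new))
    return rows


lookup_rows = build_lookup_rows(mappings)
```

These rows can be turned into a small DataFrame with spark.createDataFrame(lookup_rows, [...]) and left-joined to inventory on category_id and sub_category_id inside F.broadcast(...), with F.coalesce keeping the original value where no row matches. That replaces the filter-and-union loop with a single pass over the data, but it only covers equality conditions like these.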

I want to implement this systematically in PySpark, using a configuration dictionary to define the conditions and mappings.

I tried using a for loop to filter each condition one by one, apply the mappings, and then union each filtered result. However, the performance was very poor.

How can I achieve this efficiently in PySpark?


Solution

  • This should solve your problem, assuming your mappings follow the structure shown above. The result goes into a new mapped_sub_category_id column; afterwards you can drop the original sub_category_id column and rename the mapped one, or keep both.

    from functools import reduce
    from pyspark.sql import functions as F
    
    # Start from a copy of the original column so unmatched rows keep their value
    df = inventory.withColumn("mapped_sub_category_id", F.col("sub_category_id"))
    
    for rule in mappings:
        # AND together every condition in the rule (the config allows a list)
        category_condition = reduce(
            lambda a, b: a & b,
            [F.col(c["column"]).isin(c["values"]) for c in rule["conditions"]],
        )
    
        for original_value, mapped_value in rule["values_mapping"].items():
            df = df.withColumn(
                "mapped_sub_category_id",
                F.when(category_condition & (F.col("sub_category_id") == original_value), mapped_value)
                .otherwise(F.col("mapped_sub_category_id")),
            )