Create a dynamic case when statement based on pyspark dataframe


I have a dataframe called df with the features col1, col2 and col3. Each combination of their values should produce a result, and mapping_table defines which result each combination produces.

However, mapping_table sometimes has the value '*'. This means that the feature can have any value and does not affect the result.
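
To make the wildcard rule concrete, here is a minimal plain-Python sketch (no Spark involved). The helpers `matches` and `lookup` are hypothetical names introduced only for illustration, and returning the first matching rule is an assumed policy; the question does not say what should happen if several rules match.

```python
# Plain-Python sketch of the wildcard rule; `matches` and `lookup` are hypothetical helpers.
WILDCARD = "*"

def matches(rule_features, row):
    """True if every non-wildcard rule feature equals the row value in the same position."""
    return all(r == WILDCARD or r == v for r, v in zip(rule_features, row))

def lookup(rules, row):
    """Return the result of the first matching rule, or None if no rule matches (assumed policy)."""
    for *features, result in rules:
        if matches(features, row):
            return result
    return None

# Two of the mapping rows from the question
rules = [("a", "b", "c", "good"), ("a", "a", "*", "very good")]
print(lookup(rules, ("a", "a", "d")))  # very good
print(lookup(rules, ("b", "b", "b")))  # None
```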

The wildcard seems to make a plain join impossible(?), since I would need to decide for every row which features to use in the join condition.

What would be a good pyspark solution for this problem?

from pyspark.sql import SparkSession
from pyspark.sql.functions import col

# Create a Spark session
spark = SparkSession.builder.appName("example").getOrCreate()

# Example DataFrames
map_data = [('a', 'b', 'c', 'good'), ('a', 'a', '*', 'very good'), 
          ('b', 'd', 'c', 'bad'), ('a', 'b', 'a', 'very good'),
          ('c', 'c', '*', 'very bad'), ('a', 'b', 'b', 'bad')]

columns = ["col1", "col2", 'col3', 'result']

mapping_table = spark.createDataFrame(map_data, columns)


data = [('a', 'b', 'c'), ('a', 'a', 'b'),
        ('c', 'c', 'a'), ('c', 'c', 'b'),
        ('a', 'b', 'b'), ('a', 'a', 'd')]

columns = ["col1", "col2", 'col3']
df = spark.createDataFrame(data, columns)

Solution

  • Transform map_data into a case statement:

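    from pyspark.sql import functions as F

    # Build one WHEN branch per mapping row; '*' entries are wildcards
    # and are simply left out of the condition.
    ressql = 'case '
    for m in map_data:
        conds = [f"{c} = '{v}'" for c, v in zip(columns, m[:3]) if v != "*"]
        ressql = ressql + ' when ' + ' and '.join(conds) + f" then '{m[3]}'"
    ressql = ressql + ' end'

    # Evaluate the generated CASE expression as a new column
    df.withColumn('result', F.expr(ressql)).show()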
    

    ressql is now (shown with line breaks added for readability):

    case 
      when col1 = 'a' and col2 = 'b' and col3 = 'c' then 'good' 
      when col1 = 'a' and col2 = 'a' then 'very good' 
      when col1 = 'b' and col2 = 'd' and col3 = 'c' then 'bad' 
      when col1 = 'a' and col2 = 'b' and col3 = 'a' then 'very good' 
      when col1 = 'c' and col2 = 'c' then 'very bad' 
      when col1 = 'a' and col2 = 'b' and col3 = 'b' then 'bad' 
    end
    

    Result:

    +----+----+----+---------+
    |col1|col2|col3|   result|
    +----+----+----+---------+
    |   a|   b|   c|     good|
    |   a|   a|   b|very good|
    |   c|   c|   a| very bad|
    |   c|   c|   b| very bad|
    |   a|   b|   b|      bad|
    |   a|   a|   d|very good|
    +----+----+----+---------+
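
The string building above can also be wrapped in a small reusable function. The following is a sketch under stated assumptions, not part of the original answer: `build_case_expr` and its `default` parameter are hypothetical names, values are assumed to be plain strings without single quotes (the function refuses them instead of trying to escape), and a rule consisting only of wildcards is rendered as `when true`.

```python
def build_case_expr(rules, feature_cols, wildcard="*", default=None):
    """Build a SQL CASE expression string from mapping rows.

    Each rule is (feature values..., result). Wildcard features are left out
    of the WHEN condition; a rule made only of wildcards becomes `when true`.
    """
    if not rules:
        raise ValueError("at least one rule is required")

    def lit(value):
        # Assumption: no single quotes in values; no escaping is attempted here.
        if "'" in str(value):
            raise ValueError(f"unsupported value: {value!r}")
        return f"'{value}'"

    branches = []
    for *features, result in rules:
        conds = [f"{c} = {lit(v)}"
                 for c, v in zip(feature_cols, features) if v != wildcard]
        branches.append(f"when {' and '.join(conds) or 'true'} then {lit(result)}")
    # Optional ELSE branch for rows that match no rule
    tail = f" else {lit(default)}" if default is not None else ""
    return "case " + " ".join(branches) + tail + " end"

# Two of the mapping rows from the question
map_data = [("a", "b", "c", "good"), ("a", "a", "*", "very good")]
case_sql = build_case_expr(map_data, ["col1", "col2", "col3"], default="unknown")
print(case_sql)
```

The returned string can be passed to F.expr in the same way as ressql. Without a default, rows that match no rule get NULL, which is the standard behavior of a CASE expression with no ELSE branch.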