I have a dataframe called df with features col1, col2, col3. Their values should be combined to produce a result. The result each combination produces is defined in mapping_table.
However, mapping_table sometimes contains the value '*'. This means that the feature can have any value; it doesn't affect the result.
This makes a join impossible(?) to write, since I need to evaluate which features to use in the join for every row.
What would be a good pyspark solution for this problem?
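To pin down the intended semantics, here is a minimal plain-Python sketch of the lookup (no Spark involved). The names `lookup` and `WILDCARD` are made up for this illustration, and "first matching rule wins" is an assumption; the description above does not say what should happen when several rules match.

```python
WILDCARD = "*"  # marker meaning "any value" in the mapping (assumed name)

def lookup(row, mapping):
    """Return the result of the first rule that matches row, or None."""
    for *pattern, result in mapping:
        # A rule matches when every non-wildcard position equals the row value.
        if all(p == WILDCARD or p == v for p, v in zip(pattern, row)):
            return result
    return None

mapping = [
    ("a", "b", "c", "good"),
    ("a", "a", "*", "very good"),
    ("c", "c", "*", "very bad"),
]

print(lookup(("a", "a", "d"), mapping))  # very good
print(lookup(("b", "b", "b"), mapping))  # None
```

Whatever the Spark solution looks like, it should agree with this reference on every row.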
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
# Create a Spark session
spark = SparkSession.builder.appName("example").getOrCreate()
# Example DataFrames
map_data = [('a', 'b', 'c', 'good'), ('a', 'a', '*', 'very good'),
('b', 'd', 'c', 'bad'), ('a', 'b', 'a', 'very good'),
('c', 'c', '*', 'very bad'), ('a', 'b', 'b', 'bad')]
columns = ["col1", "col2", 'col3', 'result']
mapping_table = spark.createDataFrame(map_data, columns)
data = [('a', 'b', 'c'), ('a', 'a', 'b'),
        ('c', 'c', 'a'), ('c', 'c', 'b'),
        ('a', 'b', 'b'), ('a', 'a', 'd')]
columns = ["col1", "col2", 'col3']
df = spark.createDataFrame(data, columns)
Transform map_data into a case statement:
from pyspark.sql import functions as F

# Build one WHEN branch per mapping row, skipping '*' (wildcard) columns
ressql = 'case '
for m in map_data:
    conds = [f"{c} = '{v}'" for c, v in zip(columns, m[:3]) if v != "*"]
    ressql = ressql + ' when ' + ' and '.join(conds) + f" then '{m[3]}'"
ressql = ressql + ' end'
df.withColumn('result', F.expr(ressql)).show()
ressql is now
case
when col1 = 'a' and col2 = 'b' and col3 = 'c' then 'good'
when col1 = 'a' and col2 = 'a' then 'very good'
when col1 = 'b' and col2 = 'd' and col3 = 'c' then 'bad'
when col1 = 'a' and col2 = 'b' and col3 = 'a' then 'very good'
when col1 = 'c' and col2 = 'c' then 'very bad'
when col1 = 'a' and col2 = 'b' and col3 = 'b' then 'bad'
end
Result:
+----+----+----+---------+
|col1|col2|col3|   result|
+----+----+----+---------+
|   a|   b|   c|     good|
|   a|   a|   b|very good|
|   c|   c|   a| very bad|
|   c|   c|   b| very bad|
|   a|   b|   b|      bad|
|   a|   a|   d|very good|
+----+----+----+---------+
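The string-building loop can also be wrapped in a small helper, which makes a couple of edge cases explicit. This is only a sketch: `build_case_expr` is a hypothetical name, rejecting values that contain a single quote is a conservative choice made here rather than anything Spark requires, and turning an all-wildcard rule into `when true` is an assumption about the desired behaviour.

```python
def build_case_expr(rules, feature_cols, wildcard="*"):
    """Build a SQL CASE expression; each rule is (feature values..., result)."""
    def quote(value):
        # Refuse values that would break out of the quoted SQL literal.
        if "'" in str(value):
            raise ValueError(f"unsupported value: {value!r}")
        return f"'{value}'"

    branches = []
    for *pattern, result in rules:
        conds = [f"{c} = {quote(v)}"
                 for c, v in zip(feature_cols, pattern) if v != wildcard]
        # A rule made only of wildcards has no conditions; match every row.
        predicate = " and ".join(conds) if conds else "true"
        branches.append(f"when {predicate} then {quote(result)}")
    return "case " + " ".join(branches) + " end"

rules = [("a", "b", "c", "good"), ("a", "a", "*", "very good")]
expr = build_case_expr(rules, ["col1", "col2", "col3"])
print(expr)
# case when col1 = 'a' and col2 = 'b' and col3 = 'c' then 'good' when col1 = 'a' and col2 = 'a' then 'very good' end
```

The returned string can be passed to F.expr exactly as above. Because a CASE expression returns the first branch whose condition holds, the order of the rules matters whenever a wildcard rule overlaps a more specific one, so the specific rules should come first.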