I'm working on a PySpark DataFrame aggregation in which I am trying to dynamically generate the condition in the 'when' clause from a list of columns. However, I am encountering a syntax error. Here is a snippet of my code:
from pyspark.sql import functions as F
group_by_columns = ["class1", "class2"]
compare_columns = ["outliers"]
aggregations = [
    F.count(F.when(F.col(f"latest_{compare_columns[0]}") == True, True)).alias("Count_outliers_latest_compare_true"),
    F.count(F.when(F.col(f"previous_{compare_columns[0]}") == True, True)).alias("Count_outliers_previous_compare_true"),
    F.count(F.when(
        (F.col(f"latest_{compare_columns[0]}") == True) &
        (F.col(f"previous_{compare_columns[0]}") == True) &
        *((F.col(f"previous_{col}") == F.col(f"latest_{col}")) for col in group_by_columns),
        True
    )).alias("Count_outliers_both_compare_true_and_group_values_unchanged"),
    F.count('*').alias('Total_count_per_class')
]
I receive a syntax error related to the use of * in the when clause. I also tried [ ] instead of ( ), but got the same error. How can I correct this to avoid the syntax error?
Error:
SyntaxError: invalid syntax (, line 10)
File <->:10
    *((F.col(f"previous_{col}") == F.col(f"latest_{col}")) for col in group_by_columns),
    ^
SyntaxError: invalid syntax
To clarify the intended use of the generator expression in
*((F.col(f"previous_{col}") == F.col(f"latest_{col}")) for col in group_by_columns)
I intended this code to add conditions dynamically, one per item in group_by_columns, joined with &. For example, with just one item in group_by_columns it should produce:
(F.col(f"previous_{group_by_columns[0]}") == F.col(f"latest_{group_by_columns[0]}"))
If group_by_columns has two items:
(F.col(f"previous_{group_by_columns[0]}") == F.col(f"latest_{group_by_columns[0]}")) & (F.col(f"previous_{group_by_columns[1]}") == F.col(f"latest_{group_by_columns[1]}"))
If group_by_columns has three items:
(F.col(f"previous_{group_by_columns[0]}") == F.col(f"latest_{group_by_columns[0]}")) & (F.col(f"previous_{group_by_columns[1]}") == F.col(f"latest_{group_by_columns[1]}")) & (F.col(f"previous_{group_by_columns[2]}") == F.col(f"latest_{group_by_columns[2]}"))
etc.
I was able to resolve this by using reduce (from functools) to chain the per-column comparisons with &, starting from F.lit(True):
from functools import reduce
from pyspark.sql import functions as F

# Aggregations (this runs inside a class method, hence the self. attributes)
aggregations = [
    F.count(F.when(F.col(f"latest_{self.compare_columns[0]}") == True, True)).alias("Count_outliers_latest_compare_true"),
    F.count(F.when(F.col(f"previous_{self.compare_columns[0]}") == True, True)).alias("Count_outliers_previous_compare_true"),
    F.count(F.when(
        (F.col(f"latest_{self.compare_columns[0]}") == True) &
        (F.col(f"previous_{self.compare_columns[0]}") == True) &
        # Fold the group-by columns into one condition: True & (prev == latest) & ...
        reduce(lambda acc, col: acc & (F.col(f"previous_{col}") == F.col(f"latest_{col}")), self.group_by_columns, F.lit(True)),
        True
    )).alias("Count_outliers_both_compare_true_and_group_values_unchanged"),
    F.count('*').alias('Total_count_per_class')
]
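
For anyone wondering why the original failed: as far as I can tell, * unpacking is only accepted in places such as call arguments and list/tuple displays, not as an operand of &, so Python rejects the line before PySpark is ever involved. Below is a minimal sketch of the same reduce idea that runs without Spark. The Expr class and the all_of helper are hypothetical names I made up for illustration; Expr only mimics how a PySpark Column overloads == and & by recording the expression as text.

```python
from functools import reduce
import operator


class Expr:
    """Hypothetical stand-in for a PySpark Column: records the expression as text."""

    def __init__(self, text):
        self.text = text

    def __and__(self, other):
        return Expr(f"({self.text} & {other.text})")

    def __eq__(self, other):
        return Expr(f"({self.text} == {other.text})")


def all_of(conditions):
    """AND together one or more conditions with `&`, left to right.

    Raises TypeError on an empty iterable; the answer above avoids that
    by passing F.lit(True) as the initial value to reduce.
    """
    return reduce(operator.and_, conditions)


# A starred expression is not a valid operand of `&`, so this fails to compile.
try:
    compile("a & *conditions", "<example>", "eval")
    star_is_valid_operand = True
except SyntaxError:
    star_is_valid_operand = False

group_by_columns = ["class1", "class2"]
unchanged = all_of(
    Expr(f"previous_{c}") == Expr(f"latest_{c}") for c in group_by_columns
)
print(unchanged.text)
# ((previous_class1 == latest_class1) & (previous_class2 == latest_class2))
```

With real PySpark columns the same helper should work by passing F.col(f"previous_{c}") == F.col(f"latest_{c}") for each column, since Column also overloads &.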