python, pyspark

Syntax error in PySpark Dataframe aggregation with dynamic conditions in 'when' clause


I'm working on a PySpark DataFrame aggregation where I'm trying to dynamically generate the conditions in the 'when' clause based on a list of columns. However, I'm encountering a syntax error. Here is a snippet of my code:

from pyspark.sql import functions as F

group_by_columns = ["class1", "class2"]
compare_columns = ["outliers"]

aggregations = [
    F.count(F.when(F.col(f"latest_{compare_columns[0]}") == True, True)).alias("Count_outliers_latest_compare_true"),
    F.count(F.when(F.col(f"previous_{compare_columns[0]}") == True, True)).alias("Count_outliers_previous_compare_true"),
    F.count(F.when(
        (F.col(f"latest_{compare_columns[0]}") == True) &
        (F.col(f"previous_{compare_columns[0]}") == True) &
        *((F.col(f"previous_{col}") == F.col(f"latest_{col}")) for col in group_by_columns),
        True
    )).alias("Count_outliers_both_compare_true_and_group_values_unchanged"),
    F.count('*').alias('Total_count_per_class')
]

I receive a syntax error related to the use of * in the when clause. I've also tried using [ ] instead of ( ), but I get the same error. How can I correct this?

Error:

SyntaxError: invalid syntax (, line 10)
  File <->:10
    *((F.col(f"previous_{col}") == F.col(f"latest_{col}")) for col in group_by_columns),
    ^
SyntaxError: invalid syntax
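For context on the error itself: starred unpacking of an iterable is only valid in places such as a call's argument list or a list/tuple/set literal, not as the operand of a binary operator like &, which is exactly where it appears in the when clause above. A minimal plain-Python sketch of the difference (the names here are illustrative only, not from the original code):

values = [True, False, True]

# Allowed: unpacking into a function call's argument list.
print(*values)                 # same as print(True, False, True)

# Allowed: unpacking into a list literal.
combined = [*values, True]

# Not allowed: unpacking as an operand of a binary operator.
# Uncommenting the next line reproduces the SyntaxError.
# result = True & *values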

More clarification on the intended usage of the unpacked generator expression:

*((F.col(f"previous_{col}") == F.col(f"latest_{col}")) for col in group_by_columns)

With this code I intended to dynamically add conditions based on the items in group_by_columns. For example, with just one item in group_by_columns it should produce:

(F.col(f"previous_group_by_columns[0]") == F.col(f"latest_group_by_columns[0]"))

If group_by_columns has two items:

(F.col(f"previous_group_by_columns[0]") == F.col(f"latest_group_by_columns[0]")) & (F.col(f"previous_group_by_columns[1]") == F.col(f"latest_group_by_columns[1]"))

If group_by_columns has three items:

(F.col(f"previous_group_by_columns[0]") == F.col(f"latest_group_by_columns[0]")) & (F.col(f"previous_group_by_columns[1]") == F.col(f"latest_group_by_columns[1]")) & (F.col(f"previous_group_by_columns[2]") == F.col(f"latest_group_by_columns[2]"))

etc.
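In other words, the goal is to fold an arbitrary number of per-column comparisons into a single Column with &. One way to express that fold on its own (a sketch using functools.reduce and operator.and_; the variable names follow the snippet above, and the accepted solution below applies the same idea with a lambda):

from functools import reduce
import operator

from pyspark.sql import functions as F

group_by_columns = ["class1", "class2"]

# One comparison Column per group-by column ...
comparisons = [
    F.col(f"previous_{col}") == F.col(f"latest_{col}")
    for col in group_by_columns
]

# ... folded into a single Column with `&`. F.lit(True) keeps the
# expression valid even if group_by_columns happens to be empty.
group_values_unchanged = reduce(operator.and_, comparisons, F.lit(True))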


Solution

  • By using functools.reduce, I was able to resolve the issue (a short usage sketch follows the snippet):

    from functools import reduce  # reduce is not a builtin in Python 3

    # Aggregations (self.compare_columns and self.group_by_columns mirror
    # compare_columns and group_by_columns from the question)
    aggregations = [
        F.count(F.when(F.col(f"latest_{self.compare_columns[0]}") == True, True)).alias("Count_outliers_latest_compare_true"),
        F.count(F.when(F.col(f"previous_{self.compare_columns[0]}") == True, True)).alias("Count_outliers_previous_compare_true"),
        F.count(F.when(
            (F.col(f"latest_{self.compare_columns[0]}") == True) &
            (F.col(f"previous_{self.compare_columns[0]}") == True) &
            # Fold one comparison per group-by column into a single Column with `&`;
            # F.lit(True) seeds the fold so an empty list still yields a valid condition.
            reduce(lambda acc, col: acc & (F.col(f"previous_{col}") == F.col(f"latest_{col}")), self.group_by_columns, F.lit(True)),
            True
        )).alias("Count_outliers_both_compare_true_and_group_values_unchanged"),
        F.count('*').alias('Total_count_per_class')
    ]
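
  • For completeness, a sketch of how these aggregations might be applied (the DataFrame df and the choice of grouping columns are assumptions, not part of the original post):

    # Within the same context as above: group on the latest_* versions of the
    # group-by columns and apply the aggregation list built with reduce.
    result = df.groupBy(*[f"latest_{col}" for col in self.group_by_columns]).agg(*aggregations)
    result.show()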