Tags: python, pandas, apache-spark, pyspark, apache-spark-sql

Is there a .any() equivalent in PySpark?


I am wondering if there is a way to use .any() in PySpark?

I have the following code in Python that searches a specific column of interest in a subset of the dataframe; if any of the values in that column contain "AD", we do not want to process that ID.

Here is the code in Python:

index_list = [
    df.query("id == @trial").index
    for trial in unique_trial_id_list
    if not df.query("id == @trial")["unit"].str.upper().str.contains("AD").any()
]

Here is a sample dataframe in Pandas.

ID=1 has the string 'AD' associated with it, so we want to exclude it from processing. However, ID=2 does not have that string 'AD' associated with it and thus we want to include it in further processing.

import pandas as pd

data = [
    [1, "AD"],
    [1, "BC"],
    [1, "DE"],
    [1, "FG"],
    [2, "XY"],
    [2, "BC"],
    [2, "DE"],
    [2, "FG"],
]
df = pd.DataFrame(data, columns=["ID", "Code"])
df
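
For reference, here is how the pandas check plays out on this sample; a minimal sketch that uses the ID/Code column names from the sample frame (the snippet above uses the id/unit names from the real dataset):

keep_ids = [
    trial
    for trial in df["ID"].unique()
    if not df.loc[df["ID"] == trial, "Code"].str.upper().str.contains("AD").any()
]
keep_ids  # only ID 2 survives, since the ID=1 group contains an "AD" row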

The rub is that I do not know how to do the equivalent in PySpark. I have been able to do a list comprehension for subsetting, and have been able to subset using contains('AD'), but I am stuck when it comes to the any part of things.

PySpark Code I've come up with:

from pyspark.sql import functions as spark_fns

id = id_list[0]
test = (
    sdf.select("ID", "Codes")
    .filter(spark_fns.col("ID") == id)
    .filter(~spark_fns.col("Codes").contains("AD"))
)

Solution

  • You can use a window function (the max of a boolean column is True if there is at least one True value):

    from pyspark.sql import functions as F, Window
    
    df1 = df.withColumn(
        "to_exclude",
        # the max over each ID partition is True if any row for that ID has Code == "AD";
        # the leading ~ negates it, so the flag is True for the IDs we want to keep
        ~F.max(F.when(F.col("Code") == "AD", True).otherwise(False)).over(Window.partitionBy("ID"))
    ).filter(
        F.col("to_exclude")
    ).drop("to_exclude")
    
    df1.show()
    # +---+----+
    # | ID|Code|
    # +---+----+
    # |  2|  XY|
    # |  2|  BC|
    # |  2|  DE|
    # |  2|  FG|
    # +---+----+
    

    Or group by ID and use the max function along with a when expression to flag the IDs that contain AD in the Code column, then do a left anti join with the original df:

    from pyspark.sql import functions as F
    
    filter_df = df.groupBy("ID").agg(
        F.max(F.when(F.col("Code") == "AD", True).otherwise(False)).alias("to_exclude")
    ).filter(F.col("to_exclude"))
    
    df1 = df.join(filter_df, ["ID"], "left_anti")
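
    On the sample data, filter_df holds just the flagged ID, so the left anti join removes every ID=1 row and df1 matches the output shown above. The expected contents of the intermediate result (shown as comments):

    filter_df.show()
    # +---+----------+
    # | ID|to_exclude|
    # +---+----------+
    # |  1|      true|
    # +---+----------+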
    

    In Spark 3+, there is also an any aggregate function:

    from pyspark.sql import functions as F
    
    filter_df = df.groupBy("ID").agg(
        F.expr("any(Code = 'AD')").alias("to_exclude")
    ).filter(F.col("to_exclude"))
    
    df1 = df.join(filter_df, ["ID"], "left_anti")
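
    For completeness, here is a minimal end-to-end sketch of the Spark 3+ approach on the sample data from the question (assuming an existing SparkSession; row order after the join may differ):

    from pyspark.sql import SparkSession, functions as F
    
    spark = SparkSession.builder.getOrCreate()
    
    data = [
        (1, "AD"), (1, "BC"), (1, "DE"), (1, "FG"),
        (2, "XY"), (2, "BC"), (2, "DE"), (2, "FG"),
    ]
    df = spark.createDataFrame(data, ["ID", "Code"])
    
    # any(Code = 'AD') is true for an ID as soon as one of its rows has Code == "AD"
    filter_df = df.groupBy("ID").agg(
        F.expr("any(Code = 'AD')").alias("to_exclude")
    ).filter(F.col("to_exclude"))
    
    # left_anti keeps only the IDs that never matched
    df.join(filter_df, ["ID"], "left_anti").show()
    # +---+----+
    # | ID|Code|
    # +---+----+
    # |  2|  XY|
    # |  2|  BC|
    # |  2|  DE|
    # |  2|  FG|
    # +---+----+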