I am wondering if there is a way to use .any() in PySpark.
I have the following Python code that searches through a specific column of interest in a subset dataframe; if any value in that column contains "AD", we do not want to process that subset.
Here is the Python code:
index_list = [
    df.query("id == @trial").index
    for trial in unique_trial_id_list
    if not df.query("id == @trial")["unit"].str.upper().str.contains("AD").any()
]
Here is a sample dataframe in pandas. ID=1 has the string 'AD' associated with it, so we want to exclude it from processing. However, ID=2 does not have 'AD' associated with it, so we want to include it in further processing.
import pandas as pd

data = [
    [1, "AD"],
    [1, "BC"],
    [1, "DE"],
    [1, "FG"],
    [2, "XY"],
    [2, "BC"],
    [2, "DE"],
    [2, "FG"],
]
df = pd.DataFrame(data, columns=["ID", "Code"])
df
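For completeness, here is the same sample data as a Spark DataFrame (a minimal sketch, assuming an active SparkSession named spark); sdf is what my PySpark attempt below operates on:

sdf = spark.createDataFrame(data, ["ID", "Code"])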
The rub is that I do not know how to do the equivalent in PySpark. I have been able to do a list comprehension for subsetting, and have been able to subset using contains('AD'), but am stuck when it comes to the any part.
PySpark code I've come up with:
from pyspark.sql import functions as spark_fns

trial = id_list[0]
test = sdf.select(["ID", "Code"]).filter(spark_fns.col("ID") == trial).filter(~spark_fns.col("Code").contains("AD"))
You can use a Window function: the max of a boolean column is true if there is at least one true value in the group.
from pyspark.sql import functions as F, Window

df1 = df.withColumn(
    "to_keep",
    # max over the window is True if any row in the ID group has Code == "AD";
    # negating it gives True only for groups with no "AD"
    ~F.max(F.when(F.col("Code") == "AD", True).otherwise(False)).over(Window.partitionBy("ID"))
).filter(
    F.col("to_keep")
).drop("to_keep")
df1.show()
# +---+----+
# | ID|Code|
# +---+----+
# | 2| XY|
# | 2| BC|
# | 2| DE|
# | 2| FG|
# +---+----+
Or group by ID and use the max function along with a when expression to flag the IDs that contain AD in the Code column, then anti-join with the original df:
from pyspark.sql import functions as F

filter_df = df.groupBy("ID").agg(
    F.max(F.when(F.col("Code") == "AD", True).otherwise(False)).alias("to_exclude")
).filter(F.col("to_exclude"))

df1 = df.join(filter_df, ["ID"], "left_anti")
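As a quick sanity check, the left_anti join drops every ID flagged with "AD", so only the ID=2 rows should remain, matching the window result (row order may vary):

df1.show()
# +---+----+
# | ID|Code|
# +---+----+
# |  2|  XY|
# |  2|  BC|
# |  2|  DE|
# |  2|  FG|
# +---+----+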
In Spark 3+, there is also an any aggregate function:
from pyspark.sql import functions as F

filter_df = df.groupBy("ID").agg(
    F.expr("any(Code = 'AD')").alias("to_exclude")
).filter(F.col("to_exclude"))

df1 = df.join(filter_df, ["ID"], "left_anti")
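Putting it together, a minimal end-to-end sketch with the sample data from the question (assuming Spark 3+ and an active SparkSession named spark):

from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()

data = [
    [1, "AD"], [1, "BC"], [1, "DE"], [1, "FG"],
    [2, "XY"], [2, "BC"], [2, "DE"], [2, "FG"],
]
df = spark.createDataFrame(data, ["ID", "Code"])

# flag IDs where any Code equals "AD", then drop those IDs entirely
filter_df = df.groupBy("ID").agg(
    F.expr("any(Code = 'AD')").alias("to_exclude")
).filter(F.col("to_exclude"))

df.join(filter_df, ["ID"], "left_anti").show()
# expected: only the ID=2 rows remain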