
Check if any of the values from a list are in a PySpark column's list


I have a PySpark DataFrame with a column that I created with collect_list() in an ordinary groupBy/agg. I want to write something that returns a Boolean indicating whether at least 1 of the values in that list is also in another list of "constants":

from pyspark.sql.functions import udf
from pyspark.sql.types import BooleanType

# this is just an example of data

data = [
    (111, ["A", "B", "C"]),
    (222, ["C", "D", "E"]),
    (333, ["D", "E", "F"]),
]

schema = ["id", "my_list"]

df = spark_session.createDataFrame(data, schema=schema)

# this list is for the comparison
constants = ["A", "B", "C", "D"]

# here I want to check if at least 1 element of the list in the column is in constants
# (uses Python's built-in any() and the constants list defined above)
contains_any_udf = udf(lambda x: any(item in constants for item in x), BooleanType())

df_result = df.withColumn("is_in_col", contains_any_udf(df["my_list"]))

Is there a better way? I tried array_contains and array_intersect, but with poor results.

What I'm expecting is the same df with an additional column that contains True if at least 1 value from the "my_list" column is in the list of constants.


Solution

  • What you need is the arrays_overlap function.

    import pyspark.sql.functions as F
    ...
    df = df.withColumn('is_in_col', F.arrays_overlap('my_list', F.array([F.lit(e) for e in constants])))
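
To sanity-check the new column, it can help to compute the expected values in plain Python first. The sketch below needs no Spark session and only mirrors the "at least one common element" logic. The `overlaps` helper and row 444 are hypothetical additions for illustration, since every row in the question's sample data happens to match. It also does not model null handling: as far as I know, arrays_overlap returns null rather than False when there is no common element and either array contains a null.

```python
# Plain-Python reference for the overlap check (no Spark session needed).
# Rows 111-333 are the question's sample data; row 444 is a hypothetical
# addition so that at least one row has no match.
rows = [
    (111, ["A", "B", "C"]),
    (222, ["C", "D", "E"]),
    (333, ["D", "E", "F"]),
    (444, ["X", "Y"]),  # hypothetical row, not in the original data
]
constants = ["A", "B", "C", "D"]
constant_set = set(constants)


def overlaps(values, allowed):
    """Return True if at least one element of values is in allowed."""
    return not allowed.isdisjoint(values)


# Map each id to the value expected in the is_in_col column.
expected = {row_id: overlaps(my_list, constant_set) for row_id, my_list in rows}
print(expected)  # {111: True, 222: True, 333: True, 444: False}
```

Comparing these values against the output of df.show() is a quick way to confirm that the arrays_overlap column behaves as intended.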