Tags: arrays, apache-spark, pyspark, apache-spark-sql, case-when

How to use a when statement and array_contains in PySpark to create a new column based on conditions?


I am trying to combine a filter, a case-when statement, and an array_contains expression to flag rows in my dataset, and I would like to do it more efficiently than I do now.

My current method works, but it creates one intermediate column per element before building the final flag. I have not managed to string the three pieces together in a single expression and was hoping someone could advise.

Currently:

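import pandas as pd
import pyspark.sql.functions as spark_fns
from pyspark.sql import SparkSession

# `spark` is assumed to be an active SparkSession
spark = SparkSession.builder.getOrCreate()

data = [["a", [1, 2, 3]], ["b", [1, 2, 8]], ["c", [3, 5, 4]], ["d", [8, 1, 4]]]

df = pd.DataFrame(data, columns=["product", "list_of_values"])
sdf = spark.createDataFrame(df)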

# partially flag using array_contains to determine if an element is within list_of_values
# (integer literals, to match the integer elements of the array)
partially_flagged_sdf = (
    sdf.withColumn(
        "contains_element1",
        spark_fns.array_contains(
            sdf.list_of_values, 1
        ),
    )
    .withColumn(
        "contains_element2",
        spark_fns.array_contains(
            sdf.list_of_values, 2
        ),
    )
    .withColumn(
        "contains_element3",
        spark_fns.array_contains(
            sdf.list_of_values, 3
        ),
    )
    .withColumn(
        "contains_element4",
        spark_fns.array_contains(
            sdf.list_of_values, 4
        ),
    )
)

# using when/otherwise, add an additional flag that is True if product == "a"
# and list_of_values contains element 1 or 2
flagged_sdf = partially_flagged_sdf.withColumn(
    "proda_contains_el1",
    spark_fns.when(
        (spark_fns.col("product") == "a")
        & (spark_fns.col("contains_element1") | spark_fns.col("contains_element2")),
        True,
    ).otherwise(False),
)

Solution

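  • You can use arrays_overlap (available since Spark 2.4) to check for several elements at once. The combined condition is already a boolean column, so no when/otherwise is needed: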

    import pyspark.sql.functions as F
    
    df2 = sdf.withColumn(
        'newcol', 
        (F.col('product') == 'a') & 
        F.arrays_overlap('list_of_values', F.array(F.lit(1), F.lit(2)))
    )
    
    df2.show()
    +-------+--------------+------+
    |product|list_of_values|newcol|
    +-------+--------------+------+
    |      a|     [1, 2, 3]|  true|
    |      b|     [1, 2, 8]| false|
    |      c|     [3, 5, 4]| false|
    |      d|     [8, 1, 4]| false|
    +-------+--------------+------+
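
The rule the answer encodes (product is 'a' and list_of_values shares at least one value with [1, 2]) can also be written in plain Python, which may be handy for checking the expected flags without a Spark session. This is only a sketch: `flag_row` and its parameters are hypothetical names, not part of PySpark, and it ignores null handling, which Spark treats differently.

```python
# Plain-Python mirror of the flag logic; no Spark required.
# `flag_row` is a hypothetical helper, not a PySpark function.

def flag_row(product, list_of_values, target_product="a", wanted=(1, 2)):
    """Return True when the product matches and the list shares a value with `wanted`."""
    shares_value = not set(wanted).isdisjoint(list_of_values)
    return product == target_product and shares_value


# Same sample data as in the question
data = [["a", [1, 2, 3]], ["b", [1, 2, 8]], ["c", [3, 5, 4]], ["d", [8, 1, 4]]]

# Map each product to its flag
flags = {product: flag_row(product, values) for product, values in data}
print(flags)
```

For the sample data only product 'a' is flagged, which agrees with the newcol column shown above.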