I am trying to combine a `filter`, a `case-when` statement, and an `array_contains` expression to filter and flag columns in my dataset, and to do so more efficiently than my current approach. I have been unable to string these three elements together successfully; my current method works but isn't efficient, so I was hoping someone could advise.
Currently:
# Sample rows of (product, list_of_values); built in pandas, then
# converted to a Spark DataFrame.
data = [
    ["a", [1, 2, 3]],
    ["b", [1, 2, 8]],
    ["c", [3, 5, 4]],
    ["d", [8, 1, 4]],
]
df = pd.DataFrame(data, columns=["product", "list_of_values"])
sdf = spark.createDataFrame(df)
# partially flag using array_contains to determine if element is within list_of_values
# Flag whether list_of_values contains each of the elements 1-4, one
# boolean column per element ("contains_element1" .. "contains_element4").
# Fixes from the original: the array holds integers, so compare against
# int literals — the original passed strings like "1", which do not match
# an integer array. A loop replaces the four copy-pasted withColumn calls.
partially_flagged_sdf = sdf
for element in (1, 2, 3, 4):
    partially_flagged_sdf = partially_flagged_sdf.withColumn(
        "contains_element{}".format(element),
        spark_fns.array_contains("list_of_values", element),
    )
# using case_when and filtering, add additional flag if product == a, and list_of_values contains element 1 or 2
# Case-when flag: True where product == "a" AND the list contains 1 or 2.
# Fixes from the original: the unterminated string literal in
# col("product), the doubled "& &" operator, the missing closing
# parenthesis on withColumn(), and the redundant "== True" comparisons
# (the contains_element columns are already boolean).
flagged_sdf = partially_flagged_sdf.withColumn(
    "proda_contains_el1",
    spark_fns.when(
        (spark_fns.col("product") == 'a')
        & (
            spark_fns.col("contains_element1")
            | spark_fns.col("contains_element2")
        ),
        True,
    ).otherwise(False),
)
You can use `arrays_overlap` to check for multiple elements in a single expression:
import pyspark.sql.functions as F

# True only for product "a" whose list_of_values shares at least one
# element with [1, 2]; no intermediate flag columns needed.
is_product_a = F.col('product') == 'a'
overlaps_one_or_two = F.arrays_overlap(
    'list_of_values', F.array(F.lit(1), F.lit(2))
)
df2 = sdf.withColumn('newcol', is_product_a & overlaps_one_or_two)
df2.show()
+-------+--------------+------+
|product|list_of_values|newcol|
+-------+--------------+------+
| a| [1, 2, 3]| true|
| b| [1, 2, 8]| false|
| c| [3, 5, 4]| false|
| d| [8, 1, 4]| false|
+-------+--------------+------+