I have a dataframe with an array column of doubles. Within each array, either a single value or the sum of several values equals a given target value. I want to extract the values that either equal the target or can be summed to equal it, and I'd like to do this in PySpark.

| Array                     | Target | NewArray             |
|---------------------------|--------|----------------------|
| [0.0001, 2.5, 3.0, 0.0031] | 0.0032 | [0.0001, 0.0031]     |
| [2.5, 1.0, 0.5, 3.0]      | 3.0    | [2.5, 0.5, 3.0]      |
| [1.0, 1.0, 1.5, 1.0]      | 4.5    | [1.0, 1.0, 1.5, 1.0] |
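To pin down the expected output, here is a minimal brute-force sketch in plain Python (no Spark) that lists every combination of positions whose values sum exactly to the target; `NewArray` is then the union of the values at those positions. The function name `qualifying_subsets` is hypothetical, and the sketch uses `Decimal` to sidestep floating-point rounding. It tries every combination, so it is only practical for short arrays.

```python
from decimal import Decimal
from itertools import combinations


def qualifying_subsets(values, target):
    """Return every tuple of positions whose values sum exactly to target."""
    # Convert via str() so that 0.0001 becomes exactly Decimal("0.0001")
    nums = [Decimal(str(v)) for v in values]
    goal = Decimal(str(target))
    hits = []
    for size in range(1, len(nums) + 1):
        for idx in combinations(range(len(nums)), size):
            if sum(nums[i] for i in idx) == goal:
                hits.append(idx)
    return hits


row = [2.5, 1.0, 0.5, 3.0]
subsets = qualifying_subsets(row, 3.0)
print(subsets)  # [(3,), (0, 2)]
used = sorted({i for idx in subsets for i in idx})
print([row[i] for i in used])  # [2.5, 0.5, 3.0]
```

The second row has two qualifying subsets, `[3.0]` and `[2.5, 0.5]`, which is why all three of those values end up in `NewArray`.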
You can encapsulate the logic in a UDF and create `NewArray` from it.

I borrowed the logic for finding the elements of an array that sum to a target value from here.
```python
from decimal import Decimal

from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, DoubleType

# Already defined as `spark` in the pyspark shell and most notebooks
spark = SparkSession.builder.getOrCreate()

data = [([0.0001, 2.5, 3.0, 0.0031], 0.0032),
        ([2.5, 1.0, 0.5, 3.0], 3.0),
        ([1.0, 1.0, 1.5, 1.0], 4.5),
        ([], 1.0),
        (None, 1.0),
        ([1.0, 2.0], None)]
df = spark.createDataFrame(data, ("Array", "Target"))


@udf(returnType=ArrayType(DoubleType()))
def find_values_summing_to_target(array, target):
    def subset_sum(numbers, target, partial, result):
        s = sum(partial)
        # Check whether the partial sum equals the target
        if s == target:
            result.extend(partial)
        if s >= target:
            # Target reached or exceeded: stop extending this branch
            # (this pruning assumes the values are non-negative)
            return
        for i in range(len(numbers)):
            n = numbers[i]
            remaining = numbers[i + 1:]
            subset_sum(remaining, target, partial + [n], result)

    result = []
    if array is not None and target is not None:
        # Decimal avoids float rounding, e.g. 0.0001 + 0.0031 vs 0.0032
        array = [Decimal(str(a)) for a in array]
        subset_sum(array, Decimal(str(target)), [], result)
        result = [float(r) for r in result]
    return result


df.withColumn("NewArray", find_values_summing_to_target("Array", "Target")).show(200, False)
```
```
+--------------------------+------+--------------------+
|Array                     |Target|NewArray            |
+--------------------------+------+--------------------+
|[1.0E-4, 2.5, 3.0, 0.0031]|0.0032|[1.0E-4, 0.0031]    |
|[2.5, 1.0, 0.5, 3.0]      |3.0   |[2.5, 0.5, 3.0]     |
|[1.0, 1.0, 1.5, 1.0]      |4.5   |[1.0, 1.0, 1.5, 1.0]|
|[]                        |1.0   |[]                  |
|null                      |1.0   |[]                  |
|[1.0, 2.0]                |null  |[]                  |
+--------------------------+------+--------------------+
```
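One caveat with the UDF above: `result.extend(partial)` appends the values of every matching subset, so an element that takes part in several matching subsets is emitted several times. For example, `[1.0, 1.0, 1.0]` with target `2.0` yields six `1.0`s, because three different pairs each sum to 2.0. Below is a possible variant, sketched under the assumption that each array position should appear at most once in `NewArray`. It recurses over positions instead of values and records the positions that belong to any matching subset. The helper name `positions_summing_to_target` is hypothetical, and, like the original, its pruning step assumes non-negative values.

```python
from decimal import Decimal


def positions_summing_to_target(array, target):
    """Return the values (in original order) that belong to at least one
    subset summing exactly to target; each position is used at most once."""
    if array is None or target is None:
        return []
    nums = [Decimal(str(a)) for a in array]
    goal = Decimal(str(target))
    used = set()  # positions that appear in some matching subset

    def search(start, chosen, total):
        if chosen and total == goal:
            used.update(chosen)
        if total >= goal:
            return  # pruning is only valid for non-negative values
        for i in range(start, len(nums)):
            search(i + 1, chosen + [i], total + nums[i])

    search(0, [], Decimal(0))
    return [array[i] for i in sorted(used)]


# To use it in Spark (with the imports from the answer above):
# find_values = udf(positions_summing_to_target, ArrayType(DoubleType()))
# df.withColumn("NewArray", find_values("Array", "Target"))
```

On the sample data this produces the same `NewArray` values as the original UDF; the two only differ when matching subsets share elements.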