Consider the following table/dataframe:
|------------|-------|
| date       | value |
|------------|-------|
| 2022-01-08 | 2     |
| 2022-01-09 | 4     |
| 2022-01-10 | 6     |
| 2022-01-11 | 8     |
|------------|-------|
And the following SQL query:
WHILE (@start_date <= @end_date)
BEGIN
    UPDATE t1
    SET value = IIF(ISNULL(avg_value, 0) < 2, 0, 1)
    FROM #table t1
    OUTER APPLY (
        SELECT TOP 1 value AS avg_value
        FROM #table t2
        WHERE value >= 2
          AND t2.date < t1.date
        ORDER BY date DESC
    ) t3
    WHERE t1.date = @start_date

    SET @start_date = DATEADD(day, 1, @start_date)
END
I know my output is:
|------------|-------|-----------|
| date       | value | avg_value |
|------------|-------|-----------|
| 2022-01-08 | 0     | null      |
| 2022-01-09 | 0     | 0         |
| 2022-01-10 | 0     | 0         |
| 2022-01-11 | 0     | 0         |
|------------|-------|-----------|
The query runs an outer apply for each date, so the table is updated line by line. It is worth mentioning that the value used in the update is retrieved inside the outer apply.
In Spark, I get the values from the outer apply using a Window function and store them in an auxiliary column:
|------------|-------|-----------|
| date       | value | avg_value |
|------------|-------|-----------|
| 2022-01-08 | 0     | null      |
| 2022-01-09 | 4     | 2         |
| 2022-01-10 | 6     | 4         |
| 2022-01-11 | 8     | 6         |
|------------|-------|-----------|
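Roughly, I build that auxiliary column like this (a simplified sketch; lag stands in for the outer apply here, and df / df_aux are just illustrative names):
from pyspark.sql import functions as F, Window as W

# 'df' stands for the DataFrame equivalent of #table above.
# A window without partitionBy pulls all rows into a single partition,
# which is fine for this small illustration.
w = W.orderBy('date')
df_aux = df.withColumn('avg_value', F.lag('value').over(w))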
Then I use withColumn to perform the update on the value column; my output is:
|------------|-------|
| date       | value |
|------------|-------|
| 2022-01-08 | 0     |
| 2022-01-09 | 1     |
| 2022-01-10 | 1     |
| 2022-01-11 | 1     |
|------------|-------|
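The update itself is a conditional expression on that auxiliary column (again only a sketch, mirroring the IIF(ISNULL(avg_value, 0) < 2, 0, 1) logic; df_aux is the frame from the previous sketch):
# replicate IIF(ISNULL(avg_value, 0) < 2, 0, 1) as a when/otherwise expression
df_updated = df_aux.withColumn(
    'value',
    F.when(F.coalesce(F.col('avg_value'), F.lit(0)) < 2, 0).otherwise(1)
).drop('avg_value')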
I KNOW my Spark output is different from the SQL output, because SQL performs the update in each iteration, whereas in Spark I perform the update only after all the avg_value entries have been calculated.
MY QUESTION IS:
Is there a way to perform this query without using while loops? More specifically, is there a way to do a row-by-row update in Spark?
My original DF has about 300K rows and I want to avoid loops for performance reasons.
You say you have 300K rows. I doubt all of them contain different dates, so I assume you have certain groups. The following is the example dataframe I will be using; I have intentionally added groups covering different cases:
from pyspark.sql import functions as F, Window as W
df = spark.createDataFrame(
[(1, '2022-01-08', 2), # 0
(1, '2022-01-09', 4), # 1
(1, '2022-01-10', 6), # 1
(1, '2022-01-11', 8), # 1
(2, '2022-01-08', 0), # 0
(2, '2022-01-09', 2), # 0
(2, '2022-01-10', 6), # 1
(3, '2022-01-08', 4), # 0
(3, '2022-01-09', 6), # 1
(3, '2022-01-10', 8), # 1
(4, '2022-01-08', 0), # 0
(4, '2022-01-09', 6), # 1
(4, '2022-01-10', None), # 0
(4, '2022-01-11', 6)], # 1
['id', 'date', 'value'])
In the comments, I have provided the expected result.
What I want to show: Spark is not intended for implementing loops. Almost any such logic can be rewritten without an explicit loop.
Window functions approach
The logic in your script can be rewritten to produce the same result without looping, using a simpler construct: a window function combined with a conditional expression.
w = W.partitionBy('id').orderBy('date')
df.withColumn(
'value',
F.when((F.row_number().over(w) != 1) & (F.col('value') > 2), 1).otherwise(0)
).show()
# +---+----------+-----+
# |id |date |value|
# +---+----------+-----+
# |1 |2022-01-08|0 |
# |1 |2022-01-09|1 |
# |1 |2022-01-10|1 |
# |1 |2022-01-11|1 |
# |2 |2022-01-08|0 |
# |2 |2022-01-09|0 |
# |2 |2022-01-10|1 |
# |3 |2022-01-08|0 |
# |3 |2022-01-09|1 |
# |3 |2022-01-10|1 |
# |4 |2022-01-08|0 |
# |4 |2022-01-09|1 |
# |4 |2022-01-10|0 |
# |4 |2022-01-11|1 |
# +---+----------+-----+
"Loops" in higher-order function aggregate
The function aggregate takes an array, "loops" through every element, and returns a single value (here, that value is itself made to be an array). The lambda function performs array_union, which merges arrays that have identical schemas.
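For intuition, here is a minimal aggregate example that simply sums an array (a toy illustration, unrelated to the data above):
spark.range(1).select(
    F.aggregate(F.array(F.lit(1), F.lit(2), F.lit(3)), F.lit(0), lambda acc, x: acc + x).alias('total')
).show()
# +-----+
# |total|
# +-----+
# |    6|
# +-----+
Applied to the actual problem: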
# Collect each group into an array of (date, value) structs, sorted by date
df = df.groupBy('id').agg(F.array_sort(F.collect_list(F.struct('date', 'value'))).alias('a'))
df = df.withColumn(
    'a',
    F.slice(
        F.aggregate(
            'a',
            # the accumulator starts with a dummy struct (null date), so element_at(acc, -1) is always defined
            F.expr("array(struct(cast(null as string) date, 0 value))"),
            # append the current element with its recalculated value:
            # 1 when it is not the first row of the group and value > 2, otherwise 0
            lambda acc, x: F.array_union(
                acc,
                F.array(x.withField(
                    'value',
                    F.when(F.element_at(acc, -1)['date'].isNotNull() & (x['value'] > 2), 1).otherwise(0)
                ))
            )
        ),
        # drop the dummy seed element, keeping the original number of rows
        2, F.size('a')
    )
)
# explode the array of structs back into ordinary rows
df = df.selectExpr("id", "inline(a)")
df.show()
# +---+----------+-----+
# | id| date|value|
# +---+----------+-----+
# | 1|2022-01-08| 0|
# | 1|2022-01-09| 1|
# | 1|2022-01-10| 1|
# | 1|2022-01-11| 1|
# | 2|2022-01-08| 0|
# | 2|2022-01-09| 0|
# | 2|2022-01-10| 1|
# | 3|2022-01-08| 0|
# | 3|2022-01-09| 1|
# | 3|2022-01-10| 1|
# | 4|2022-01-08| 0|
# | 4|2022-01-09| 1|
# | 4|2022-01-10| 0|
# | 4|2022-01-11| 1|
# +---+----------+-----+
This way you can "loop" through the elements of an array. But be cautious about the size of the arrays, since each group's entire array is held on a single cluster node.