
Update Spark Dataframe row by row


Consider the following table/dataframe:

|------------------|
|date       | value|
|------------------|
|2022-01-08 | 2    |
|2022-01-09 | 4    |
|2022-01-10 | 6    |
|2022-01-11 | 8    |
|------------------|

And the following SQL query:

WHILE (@start_date <= @end_date)
BEGIN
    update t1 set value = 
        IIF(ISNULL(avg_value,0) < 2, 0,1)
    from #table t1
    outer apply (
        select 
            top 1 value as avg_value
        FROM 
            #table t2
        WHERE
            value >= 2 AND
            t2.date < t1.date
        ORDER BY date DESC
    ) t3
    where t1.date = @start_date
    SET @start_date = dateadd(day,1, @start_date)
END

I know my output is:

|------------------------------|
|date       | value | avg_value|
|------------------------------|
|2022-01-08 | 0     | null     |
|2022-01-09 | 0     | 0        |
|2022-01-10 | 0     | 0        |
|2022-01-11 | 0     | 0        |
|------------------------------|

The query runs the OUTER APPLY once per date, so the table is updated row by row. Note that the new value is derived from the result of the OUTER APPLY, which reads rows that earlier iterations may already have updated.

In Spark, I compute the OUTER APPLY values with a window function and store them in an auxiliary column:

|-------------------------------|
|date       | value | avg_value |
|-------------------------------|
|2022-01-08 | 0     | null      |
|2022-01-09 | 4     | 2         |
|2022-01-10 | 6     | 4         |
|2022-01-11 | 8     | 6         |
|-------------------------------|

Then I use withColumn to update the value column. My output is:

|-------------------|
|date       | value |
|-------------------|
|2022-01-08 | 0     |
|2022-01-09 | 1     |
|2022-01-10 | 1     |
|2022-01-11 | 1     |
|-------------------|

I KNOW my Spark output differs from the SQL output: SQL performs the update in each iteration, whereas in Spark I apply the update only after all the avg_value entries have been calculated.
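
To make the difference concrete, here is a minimal plain-Python sketch of the two behaviours (no Spark or SQL Server involved). The names `previous_qualifying`, `flag`, `sequential_update` and `one_shot_update` are hypothetical helpers introduced only for this illustration, and it assumes the WHILE loop visits every date in the table in ascending order.

```python
from typing import List, Optional, Tuple

Row = Tuple[str, Optional[int]]

ROWS: List[Row] = [
    ("2022-01-08", 2),
    ("2022-01-09", 4),
    ("2022-01-10", 6),
    ("2022-01-11", 8),
]


def previous_qualifying(rows: List[Row], date: str) -> Optional[int]:
    """Mimic the OUTER APPLY: value of the latest earlier row with value >= 2."""
    earlier = [(d, v) for d, v in rows if d < date and v is not None and v >= 2]
    if not earlier:
        return None
    # ISO date strings sort chronologically, so max by date is "TOP 1 ... DESC".
    return max(earlier, key=lambda r: r[0])[1]


def flag(avg_value: Optional[int]) -> int:
    """Mimic IIF(ISNULL(avg_value, 0) < 2, 0, 1)."""
    return 0 if (avg_value or 0) < 2 else 1


def sequential_update(rows: List[Row]) -> List[Row]:
    """Row by row: each lookup sees values written by earlier iterations."""
    table = list(rows)
    for i, (date, _) in enumerate(table):
        table[i] = (date, flag(previous_qualifying(table, date)))
    return table


def one_shot_update(rows: List[Row]) -> List[Row]:
    """Set-based: every lookup reads the original, unmodified values."""
    return [(date, flag(previous_qualifying(rows, date))) for date, _ in rows]


print([v for _, v in sequential_update(ROWS)])  # SQL-style loop
print([v for _, v in one_shot_update(ROWS)])    # window-then-update style
```

The first call reproduces the all-zero SQL result, because the first row is overwritten with 0 before any later row looks at it. The second reproduces my Spark result.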

MY QUESTION IS:

Is there a way to perform this query without using while loops? More specifically, is there a way to update row by row in Spark?

My original DataFrame has about 300K rows, and I want to avoid loops for performance reasons.


Solution

  • You say you have 300K rows. I doubt they all have different dates, so I assume the data falls into groups. Below is the example DataFrame I will use. I have intentionally added groups that cover different cases:

    from pyspark.sql import functions as F, Window as W
    
    df = spark.createDataFrame(
        [(1, '2022-01-08', 2),    # 0
         (1, '2022-01-09', 4),    # 1
         (1, '2022-01-10', 6),    # 1
         (1, '2022-01-11', 8),    # 1
    
         (2, '2022-01-08', 0),    # 0
         (2, '2022-01-09', 2),    # 0
         (2, '2022-01-10', 6),    # 1
    
         (3, '2022-01-08', 4),    # 0
         (3, '2022-01-09', 6),    # 1
         (3, '2022-01-10', 8),    # 1
    
         (4, '2022-01-08', 0),    # 0
         (4, '2022-01-09', 6),    # 1
         (4, '2022-01-10', None), # 0
         (4, '2022-01-11', 6)],   # 1
        ['id', 'date', 'value'])
    

    The comments show the expected result for each row.

    What I am trying to show: Spark is not designed for explicit loops. Almost any logic can be rewritten so that it does not need them.


    Window functions approach

    The logic in your script can be rewritten to do the same with a simpler, loop-free algorithm: a window function plus a conditional expression.

    w = W.partitionBy('id').orderBy('date')
    df.withColumn(
        'value',
        F.when((F.row_number().over(w) != 1) & (F.col('value') > 2), 1).otherwise(0)
    ).show()
    # +---+----------+-----+
    # |id |date      |value|
    # +---+----------+-----+
    # |1  |2022-01-08|0    |
    # |1  |2022-01-09|1    |
    # |1  |2022-01-10|1    |
    # |1  |2022-01-11|1    |
    # |2  |2022-01-08|0    |
    # |2  |2022-01-09|0    |
    # |2  |2022-01-10|1    |
    # |3  |2022-01-08|0    |
    # |3  |2022-01-09|1    |
    # |3  |2022-01-10|1    |
    # |4  |2022-01-08|0    |
    # |4  |2022-01-09|1    |
    # |4  |2022-01-10|0    |
    # |4  |2022-01-11|1    |
    # +---+----------+-----+
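
    As a quick sanity check of the rule above, the same logic can be mirrored in plain Python with no Spark involved. This is only a sketch: `ROWS` and `window_rule` are hypothetical names introduced here, `enumerate` stands in for `row_number()`, and the explicit `None` check imitates how `when` treats a null comparison as not satisfied.

```python
from itertools import groupby
from typing import List, Optional, Tuple

Row = Tuple[int, str, Optional[int]]

# Same sample data as the example DataFrame: (id, date, value).
ROWS: List[Row] = [
    (1, "2022-01-08", 2), (1, "2022-01-09", 4),
    (1, "2022-01-10", 6), (1, "2022-01-11", 8),
    (2, "2022-01-08", 0), (2, "2022-01-09", 2), (2, "2022-01-10", 6),
    (3, "2022-01-08", 4), (3, "2022-01-09", 6), (3, "2022-01-10", 8),
    (4, "2022-01-08", 0), (4, "2022-01-09", 6),
    (4, "2022-01-10", None), (4, "2022-01-11", 6),
]


def window_rule(rows: List[Row]) -> List[Tuple[int, str, int]]:
    """Flag a row with 1 if it is not first in its id group and value > 2."""
    out = []
    # Sorting by (id, date) plays the role of partitionBy('id').orderBy('date').
    ordered = sorted(rows, key=lambda r: (r[0], r[1]))
    for _, group in groupby(ordered, key=lambda r: r[0]):
        for row_number, (id_, date, value) in enumerate(group, start=1):
            hit = row_number != 1 and value is not None and value > 2
            out.append((id_, date, 1 if hit else 0))
    return out


flags = [v for _, _, v in window_rule(ROWS)]
print(flags)
```

    The flags come out in the same order as the expected-result comments on the example DataFrame.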
    

    "Loops" in higher-order function aggregate

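    The function aggregate takes an array, "loops" through every element and returns one value (here, that value is itself an array).

    The lambda function calls array_union, which merges two arrays that have identical schemas. Note that array_union also drops duplicate elements, so this relies on every struct within a group being distinct (here, the dates are unique per id).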

    df = df.groupBy('id').agg(F.array_sort(F.collect_list(F.struct('date', 'value'))).alias('a'))
    df = df.withColumn(
        'a',
        F.slice(
            F.aggregate(
                'a',
                F.expr("array(struct(cast(null as string) date, 0 value))"),
                lambda acc, x: F.array_union(
                    acc,
                    F.array(x.withField(
                        'value',
                        F.when(F.element_at(acc, -1)['date'].isNotNull() & (x['value'] > 2), 1).otherwise(0)
                    ))
                )
            ),
            2, F.size('a')
        )
    )
    df = df.selectExpr("id", "inline(a)")
    
    df.show()
    # +---+----------+-----+
    # | id|      date|value|
    # +---+----------+-----+
    # |  1|2022-01-08|    0|
    # |  1|2022-01-09|    1|
    # |  1|2022-01-10|    1|
    # |  1|2022-01-11|    1|
    # |  2|2022-01-08|    0|
    # |  2|2022-01-09|    0|
    # |  2|2022-01-10|    1|
    # |  3|2022-01-08|    0|
    # |  3|2022-01-09|    1|
    # |  3|2022-01-10|    1|
    # |  4|2022-01-08|    0|
    # |  4|2022-01-09|    1|
    # |  4|2022-01-10|    0|
    # |  4|2022-01-11|    1|
    # +---+----------+-----+
    

    This way you can "loop" through the elements of an array. Be careful with the size of the arrays, though: each array must fit in the memory of a single cluster node.
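
    For intuition, the fold that aggregate performs can be mimicked in plain Python with `functools.reduce`. This is a rough analogy rather than Spark code: `step`, `SEED` and `GROUP` are hypothetical names, tuples stand in for structs, and list concatenation stands in for array_union (it does not deduplicate).

```python
from functools import reduce
from typing import List, Optional, Tuple

Item = Tuple[Optional[str], Optional[int]]

# One group (id = 4 in the sample data) as (date, value) pairs.
GROUP: List[Item] = [
    ("2022-01-08", 0),
    ("2022-01-09", 6),
    ("2022-01-10", None),
    ("2022-01-11", 6),
]

# Dummy first element, like array(struct(null date, 0 value)) in the Spark code.
SEED: List[Item] = [(None, 0)]


def step(acc: List[Item], row: Item) -> List[Item]:
    """Append the row with its new value, looking at the accumulator's last item."""
    date, value = row
    previous_date = acc[-1][0]  # analogous to element_at(acc, -1)['date']
    hit = previous_date is not None and value is not None and value > 2
    return acc + [(date, 1 if hit else 0)]


# Sort by date (like array_sort), fold, then drop the dummy seed (like slice).
folded = reduce(step, sorted(GROUP, key=lambda r: r[0]), SEED)[1:]
print(folded)
```

    Because each step can read everything accumulated so far, this pattern can express logic that depends on previously updated rows, which a plain window function cannot.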