
Is it possible to write a self-referencing column in PySpark?


I'm writing a small POC, trying to rewrite a piece of logic from Python to PySpark, where I'm processing logs stored in SQLite one by one:

logs = [...]
processed_logs = []
previous_log = EmptyDecoratedLog() #empty
for log in logs:
    processed_log = with_outlet_value_closed(log, previous_log)
    previous_log = processed_log 
    processed_logs.append(processed_log)

and

def with_outlet_value_closed(current_entry: DecoratedLog, previous_entry: DecoratedLog) -> DecoratedLog:
    if current_entry.sourceName == "GS2":
        current_entry.outletValveClosed = current_entry.eventData
    else:
        current_entry.outletValveClosed = previous_entry.outletValveClosed
    return current_entry

which I wanted to represent in the PySpark API as:

import pyspark.sql.functions as f
from pyspark.sql import Window as W

window = W.orderBy("ID")  # where ID is a unique id on those logs
df.withColumn("testValveOpened",
                f.when((f.col("sourceName") == "GS2"), f.col("eventData"))
                .otherwise(f.lag("testValveOpened").over(window)),
                )

but this leads to AnalysisException: [UNRESOLVED_COLUMN.WITH_SUGGESTION] A column or function parameter with name testValveOpened cannot be resolved.

So my question is: is it possible to represent code where the value of the current row depends on the previous row of the same column? (I know this will result in all records being processed on a single thread, but that's fine.)

I've tried adding an initialization of the column:

df = df.withColumn("testValveOpened", f.lit(0))
df.withColumn("testValveOpened",
                f.when((f.col("sourceName") == "GS2"), f.col("eventData"))
                .otherwise(f.lag("testValveOpened").over(window)),
                )

but then I'm getting

ID |sourceName|eventData|testValveOpened
1  |GS3       |1        |0
2  |GS2       |1        |1
3  |GS2       |8        |8
4  |GS1       |1        |0
5  |GS2       |2        |0
6  |ABC       |0        |0
7  |B123      |0        |0
8  |B423      |0        |0
9  |PTSD      |168      |0
10 |XCD       |0        |0

I would like to get

ID |sourceName|eventData|testValveOpened
1  |GS3       |1        |0
2  |GS2       |1        |1
3  |GS2       |8        |8
4  |GS1       |1        |8
5  |GS2       |2        |2
6  |ABC       |0        |2
7  |B123      |0        |2
8  |B423      |0        |2
9  |PTSD      |168      |2
10 |XCD       |0        |2  

so when sourceName is GS2, take the value of eventData; otherwise carry the value forward from the previous testValveOpened.


Solution

  • You have to rewrite the logic a bit, as you cannot update rows 'one by one'. First flag the GS2 rows:

    df.withColumn("testValveOpened", f.when(f.col("sourceName") == "GS2", f.lit(1)).otherwise(f.lit(0)))
    

    then take a cumulative sum and compare, to see whether a GS2 row appeared at or before the current row:

    df.withColumn("testValveOpened", f.sum("testValveOpened").over(window) >= 1)
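    The cumulative sum gives a boolean flag, whereas the desired output carries the actual eventData value forward. One way to express that carry-forward directly is `f.last(..., ignorenulls=True)` over the same window: mark non-GS2 rows as null, then take the last non-null value up to the current row. A runnable sketch (the local SparkSession and inlined sample data are only for demonstration):

    ```python
    from pyspark.sql import SparkSession, Window
    import pyspark.sql.functions as f

    spark = SparkSession.builder.master("local[1]").getOrCreate()

    # Sample data matching the question's table
    df = spark.createDataFrame(
        [(1, "GS3", 1), (2, "GS2", 1), (3, "GS2", 8), (4, "GS1", 1),
         (5, "GS2", 2), (6, "ABC", 0), (7, "B123", 0), (8, "B423", 0),
         (9, "PTSD", 168), (10, "XCD", 0)],
        ["ID", "sourceName", "eventData"],
    )

    # Unpartitioned window: everything on one task, fine for this POC
    window = Window.orderBy("ID")

    df = df.withColumn(
        "testValveOpened",
        f.coalesce(
            # eventData on GS2 rows, null elsewhere; last() with
            # ignorenulls=True carries the most recent GS2 value forward
            f.last(
                f.when(f.col("sourceName") == "GS2", f.col("eventData")),
                ignorenulls=True,
            ).over(window),
            f.lit(0),  # rows before the first GS2
        ),
    )
    ```

    This avoids the self-reference entirely: the column is defined in one pass from sourceName and eventData, rather than from its own previous row.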