
PySpark Create a new lag() column from an existing column and fillna with existing column value


I am converting my Pandas code to PySpark. I need to create a new column from an existing one by grouping the data on 'session' and shifting it, so that each row gets the next row's 'timestamp' as 'next_timestamp'. For the last row in every group I get a null value; in pandas I overcame this by filling the NA with the existing column value, and I need to achieve the same in PySpark.

Below is the expected output:

| session | timestamp | next_timestamp |
| ------- | --------- | --------------- |
| 1       | 100       | 101             |
| 1       | 101       | 102             |
| 1       | 102       | 103             |
| 1       | 103       | 104             |
| 1       | 104       | 105             |
| 1       | 105       | 106             |
| 1       | 106       | 107             |
| 1       | 107       | 107             |
| 2       | 108       | 109             |
| 2       | 109       | 110             |
| 2       | 110       | 111             |
| 2       | 111       | 112             |
| 2       | 112       | 112             |
| 3       | 113       | 114             |
| 3       | 114       | 115             |
| 3       | 115       | 116             |
| 3       | 116       | 117             |
| 3       | 117       | 118             |
| 3       | 118       | 118             |

Below is the output that I am currently getting:

Code

from pyspark.sql import Window
from pyspark.sql.functions import lag

df = df.withColumn('next_timestamp', lag('timestamp', -1).over(Window.partitionBy('sessionId').orderBy('timestamp')))

Output

| session | timestamp | next_timestamp |
| ------- | --------- | --------------- |
| 1       | 100       | 101             |
| 1       | 101       | 102             |
| 1       | 102       | 103             |
| 1       | 103       | 104             |
| 1       | 104       | 105             |
| 1       | 105       | 106             |
| 1       | 106       | 107             |
| 1       | 107       |                 |
| 2       | 108       | 109             |
| 2       | 109       | 110             |
| 2       | 110       | 111             |
| 2       | 111       | 112             |
| 2       | 112       |                 |
| 3       | 113       | 114             |
| 3       | 114       | 115             |
| 3       | 115       | 116             |
| 3       | 116       | 117             |
| 3       | 117       | 118             |
| 3       | 118       |                 |

I was able to achieve the expected output in pandas with the below logic

df['next_timestamp'] = df['timestamp'].groupby(df['sessionId']).shift(-1).fillna(df['timestamp'])

I need to do the same in PySpark, filling the NA with either ffill() or the row's own col('timestamp') value.
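
For reproducibility, here is a minimal sketch that builds the sample input shown above (I use 'sessionId' as the grouping column to match my code, although the tables label it 'session'):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Sessions 1-3 with consecutive timestamps, matching the tables above
data = [(1, t) for t in range(100, 108)] + \
       [(2, t) for t in range(108, 113)] + \
       [(3, t) for t in range(113, 119)]
df = spark.createDataFrame(data, ['sessionId', 'timestamp'])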


Solution

  • I think a simple when-otherwise transformation should do the job:

    from pyspark.sql.functions import col, when
    
    df.withColumn("next_timestamp", when(col("next_timestamp").isNotNull(), col("next_timestamp")).otherwise(col("timestamp")))
    
    | session | timestamp | next_timestamp |
    | ------- | --------- | --------------- |
    | 1       | 100       | 101             |
    | 1       | 101       | 102             |
    | 1       | 102       | 103             |
    | 1       | 103       | 104             |
    | 1       | 104       | 105             |
    | 1       | 105       | 106             |
    | 1       | 106       | 107             |
    | 1       | 107       | 107             |
    | 2       | 108       | 109             |
    | 2       | 109       | 110             |
    | 2       | 110       | 111             |
    | 2       | 111       | 112             |
    | 2       | 112       | 112             |
    | 3       | 113       | 114             |
    | 3       | 114       | 115             |
    | 3       | 115       | 116             |
    | 3       | 116       | 117             |
    | 3       | 117       | 118             |
    | 3       | 118       | 118             |
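
    As an aside, the lag() plus fill can also be collapsed into a single step with coalesce and lead; a sketch assuming the same window as in the question:

    from pyspark.sql import Window
    from pyspark.sql.functions import coalesce, col, lead

    w = Window.partitionBy('sessionId').orderBy('timestamp')

    # lead(..., 1) looks one row ahead within the session; coalesce falls back to
    # the row's own timestamp for the last row of each group
    df = df.withColumn('next_timestamp', coalesce(lead('timestamp', 1).over(w), col('timestamp')))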