dataframe, apache-spark, pyspark, dynamic, aggregate

Shift rows dynamically based on column value


Below is my input dataframe:

+---+----------+--------+
|ID |date      |shift_by|
+---+----------+--------+
|1  |2021-01-01|2       |
|1  |2021-02-05|2       |
|1  |2021-03-27|2       |
|2  |2022-02-28|1       |
|2  |2022-04-30|1       |
+---+----------+--------+

I need to groupBy "ID" and shift the rows based on the "shift_by" column. In the end, the result should look like this:

+---+----------+----------+
|ID |date1     |date2     |
+---+----------+----------+
|1  |2021-01-01|2021-03-27|
|2  |2022-02-28|2022-04-30|
+---+----------+----------+

I have implemented the logic using a UDF, but it makes my code slow. I would like to understand whether this logic can be implemented without a UDF.

Below is a sample dataframe with the expected output:

import datetime
from pyspark.sql.types import *

data2 = [(1, datetime.date(2021, 1, 1), datetime.date(2021, 3, 27)),
    (2, datetime.date(2022, 2, 28), datetime.date(2022, 4, 30))
]
schema = StructType([
    StructField("ID", IntegerType(), True),
    StructField("date1", DateType(), True),
    StructField("date2", DateType(), True),
])
df = spark.createDataFrame(data=data2, schema=schema)
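
For reference, the input dataframe from the table above can be built the same way; a quick sketch (the names input_data, input_schema and input_df are just illustrative):

input_data = [(1, datetime.date(2021, 1, 1), 2),
    (1, datetime.date(2021, 2, 5), 2),
    (1, datetime.date(2021, 3, 27), 2),
    (2, datetime.date(2022, 2, 28), 1),
    (2, datetime.date(2022, 4, 30), 1)
]
input_schema = StructType([
    StructField("ID", IntegerType(), True),
    StructField("date", DateType(), True),
    StructField("shift_by", IntegerType(), True),
])
input_df = spark.createDataFrame(data=input_data, schema=input_schema)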

Solution

  • Based on the comments and chat, you can calculate the first and last values of the lat/lon fields of concern (shown below with a generic foo column):

    import sys

    import pyspark.sql.functions as func
    from pyspark.sql.window import Window as wd

    # window covering the entire partition for each id, ordered by date
    full_window = wd.partitionBy('id').orderBy('date').rowsBetween(-sys.maxsize, sys.maxsize)

    data_sdf. \
        withColumn('foo_first', func.first('foo').over(full_window)). \
        withColumn('foo_last', func.last('foo').over(full_window)). \
        select('id', 'foo_first', 'foo_last'). \
        dropDuplicates()
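
    The rowsBetween(-sys.maxsize, sys.maxsize) frame makes the window span the whole partition; without it, last would use the default frame (unbounded preceding to current row) and simply return the current row's value. wd.unboundedPreceding and wd.unboundedFollowing are the more conventional bounds for the same effect.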
    

    Or, you can create structs and take the min/max:

    data_sdf = spark.createDataFrame(
        [(1, '2021-01-01', 2, 2),
         (1, '2021-02-05', 3, 2),
         (1, '2021-03-27', 4, 2),
         (2, '2022-02-28', 1, 5),
         (2, '2022-04-30', 5, 1)],
        ['ID', 'date', 'lat', 'lon'])
    
    # structs compare field by field, so min/max effectively order by date (the first field)
    data_sdf. \
        withColumn('dt_lat_lon_struct', func.struct('date', 'lat', 'lon')). \
        groupBy('id'). \
        agg(func.min('dt_lat_lon_struct').alias('min_dt_lat_lon_struct'),
            func.max('dt_lat_lon_struct').alias('max_dt_lat_lon_struct')
            ). \
        selectExpr('id', 
                   'min_dt_lat_lon_struct.lat as lat_first', 'min_dt_lat_lon_struct.lon as lon_first',
                   'max_dt_lat_lon_struct.lat as lat_last', 'max_dt_lat_lon_struct.lon as lon_last'
                   )
    
    # +---+---------+---------+--------+--------+
    # | id|lat_first|lon_first|lat_last|lon_last|
    # +---+---------+---------+--------+--------+
    # |  1|        2|        2|       4|       2|
    # |  2|        1|        5|       5|       1|
    # +---+---------+---------+--------+--------+
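
    For the original question's output (date1 and date2 per ID), the same idea can be applied to the date column alone, with no UDF; a minimal sketch, assuming the input dataframe from the question is available as input_sdf (an illustrative name) and that, as above, the shifted row is always the last row per ID:

    # min/max of date per ID reproduces the expected date1/date2
    input_sdf. \
        groupBy('ID'). \
        agg(func.min('date').alias('date1'),
            func.max('date').alias('date2'))

    # expected, per the question:
    # +---+----------+----------+
    # | ID|     date1|     date2|
    # +---+----------+----------+
    # |  1|2021-01-01|2021-03-27|
    # |  2|2022-02-28|2022-04-30|
    # +---+----------+----------+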