
pyspark transform to find offset start and end


In PySpark, I am trying to compute offsets from the lengths stored in the array column "Col1". I don't want to use a UDF, so I am trying to build the solution with transform, but I keep hitting errors. Please suggest a workaround.

Col1            Offset
[3,4,6,2,1]     [[0,3],[4,8],[9,15],[16,18],[19,20]]
[10,5,4,3,2]    [[0,10],[11,16],[17,21],[22,25],[26,28]]

from pyspark.sql import SparkSession
import pyspark.sql.functions as F


spark = SparkSession.builder \
    .appName("Calculate Offset Column") \
    .getOrCreate()


data = [([3.0, 4.0, 6.0, 2.0, 1.0],),
        ([10.0, 5.0, 4.0, 3.0, 2.0],)]


df = spark.createDataFrame(data, ["Col1"])
df = df.withColumn("Offsets",
    F.expr("""transform(Col1, (x, i) -> struct(
        coalesce(sum(Col1) over (order by i rows between unbounded preceding and current row) - x, 0) as start,
        sum(Col1) over (order by i rows between unbounded preceding and current row) as end))"""))

Error: Resolved attribute(s) i#462 missing from Col1#454 in operator !Window [Col1#454, transform(Col1#454, lambdafunction(struct(start, coalesce((sum(cast(Col1 as double)) windowspecdefinition(lambda i#462 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) - lambda x#461), cast(0 as double)), end, sum(cast(Col1 as double)) windowspecdefinition(lambda i#462 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$()))), lambda x#461, lambda i#462, false)) AS Offsets#458], [lambda i#462 ASC NULLS FIRST].;


Solution

  • from pyspark.sql import SparkSession
    import pyspark.sql.functions as F 
    
    spark = SparkSession.builder \
        .appName("Calculate Offset Column") \
        .getOrCreate()
    
    
    data = [([3.0, 4.0, 6.0, 2.0, 1.0],),
            ([10.0, 5.0, 4.0, 3.0, 2.0],)]
    
    
    df = spark.createDataFrame(data, ["Col1"])
    
    df = (df
           # pair each length with its index: (i, x)
           .withColumn("Offsets_tmp",
               F.expr("""transform(Col1, (x, i) -> (i, x))"""))
           # one row per (index, length) pair
           .select("Col1", F.explode("Offsets_tmp").alias("expl"))
           .selectExpr(
               "Col1"
               # end offset = cumulative sum of lengths + element index (leaves a 1-unit gap between ranges)
               ,"SUM(expl.x) OVER (PARTITION BY Col1 ORDER BY expl) + expl.i right_side"
               ,"CASE WHEN expl.i = 0 THEN 0 ELSE right_side - expl.x END left_side"
               ,"ARRAY(left_side, right_side) arr1"
           )
           # collect the per-element [start, end] pairs back into a single array per input row
           .groupBy("Col1")
           .agg(F.collect_list("arr1").alias("Offset"))
           .select("Col1", F.array_sort("Offset").alias("Offset"))
          )
    
    df.show(truncate=False)
    
    +--------------------------+---------------------------------------------------------------------+
    |Col1                      |Offset                                                               |
    +--------------------------+---------------------------------------------------------------------+
    |[3.0, 4.0, 6.0, 2.0, 1.0] |[[0.0, 3.0], [4.0, 8.0], [9.0, 15.0], [16.0, 18.0], [19.0, 20.0]]    |
    |[10.0, 5.0, 4.0, 3.0, 2.0]|[[0.0, 10.0], [11.0, 16.0], [17.0, 21.0], [22.0, 25.0], [26.0, 28.0]]|
    +--------------------------+---------------------------------------------------------------------+
    

    Tested and works on Spark 3.5.0.

    Note: this might not be the prettiest option; you could replace most of the selectExpr with equivalent PySpark code (e.g., pyspark.sql.Window), as sketched below.
    My personal preference is SQL for window functions, though, as it is a bit less verbose in my opinion.
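
    For illustration, a rough sketch of that Window-based variant (an assumption on my part, not part of the tested answer; df_expl is a hypothetical name for the frame right after the explode step, with columns Col1 and expl):

    from pyspark.sql import Window

    w = Window.partitionBy("Col1").orderBy("expl")

    offsets = (df_expl  # hypothetical: the exploded frame from the chain above
                # end offset = running sum of lengths + element index
                .withColumn("right_side", F.sum("expl.x").over(w) + F.col("expl.i"))
                # start offset = 0 for the first element, otherwise end - length
                .withColumn("left_side",
                    F.when(F.col("expl.i") == 0, F.lit(0.0))
                     .otherwise(F.col("right_side") - F.col("expl.x")))
                .withColumn("arr1", F.array("left_side", "right_side"))
                .groupBy("Col1")
                .agg(F.array_sort(F.collect_list("arr1")).alias("Offset"))
               )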

    EDIT: added array_sort. Not sure whether it matters to you, but collect_list might return the elements in a different order after a shuffle. Since the Offset column is expected to be monotonically ascending, array_sort can be used to make the output deterministic.
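
    As a quick aside on why array_sort restores the right order here: Spark compares arrays lexicographically, so sorting the [start, end] pairs effectively sorts by start. A throwaway check (my own example, not from the answer):

    spark.sql(
        "SELECT array_sort(array(array(9.0, 15.0), array(0.0, 3.0), array(4.0, 8.0))) AS s"
    ).first()["s"]
    # -> [[0.0, 3.0], [4.0, 8.0], [9.0, 15.0]]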