In PySpark, I am trying to compute offsets from a length array column "Col1". I don't want to use a UDF, so I am trying to get a solution using transform, but I am facing errors. Please suggest a workaround.
Col1           Offset
[3,4,6,2,1]    [[0,3],[4,8],[9,15],[16,18],[19,20]]
[10,5,4,3,2]   [[0,10],[11,16],[17,21],[22,25],[26,28]]
from pyspark.sql import SparkSession
import pyspark.sql.functions as f

spark = SparkSession.builder \
    .appName("Calculate Offset Column") \
    .getOrCreate()

data = [([3.0, 4.0, 6.0, 2.0, 1.0],),
        ([10.0, 5.0, 4.0, 3.0, 2.0],)]
df = spark.createDataFrame(data, ["Col1"])

df = df.withColumn(
    "Offsets",
    f.expr("""transform(Col1, (x, i) -> struct(
        coalesce(sum(Col1) over (order by i rows between unbounded preceding and current row) - x, 0) as start,
        sum(Col1) over (order by i rows between unbounded preceding and current row) as end))"""))
Error: Resolved attribute(s) i#462 missing from Col1#454 in operator !Window [Col1#454, transform(Col1#454, lambdafunction(struct(start, coalesce((sum(cast(Col1 as double)) windowspecdefinition(lambda i#462 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) - lambda x#461), cast(0 as double)), end, sum(cast(Col1 as double)) windowspecdefinition(lambda i#462 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$()))), lambda x#461, lambda i#462, false)) AS Offsets#458], [lambda i#462 ASC NULLS FIRST].;
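The window function inside the transform lambda orders by the lambda index i, which only exists inside the lambda, so the window operator cannot resolve it; that is what the error is pointing at. As a minimal sketch of a window-free reformulation (my own assumption, not part of the original post, and untested), the running sum can be rebuilt inside transform with slice + aggregate:

# Sketch only: build the prefix sum per element with slice + aggregate instead
# of a window function, so nothing outside the lambda needs to see i.
# The "+ i" accounts for the one-position gap between consecutive segments
# in the expected Offset output.
df = df.withColumn(
    "Offsets",
    f.expr("""
        transform(Col1, (x, i) ->
            array(
                aggregate(slice(Col1, 1, i + 1), CAST(0 AS DOUBLE), (acc, y) -> acc + y) + i - x,
                aggregate(slice(Col1, 1, i + 1), CAST(0 AS DOUBLE), (acc, y) -> acc + y) + i
            )
        )
    """))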
from pyspark.sql import SparkSession
import pyspark.sql.functions as F

spark = SparkSession.builder \
    .appName("Calculate Offset Column") \
    .getOrCreate()

data = [([3.0, 4.0, 6.0, 2.0, 1.0],),
        ([10.0, 5.0, 4.0, 3.0, 2.0],)]
df = spark.createDataFrame(data, ["Col1"])

df = (df
      # pair every length with its index, then explode into one row per element
      .withColumn("Offsets_tmp", F.expr("""transform(Col1, (x, i) -> (i, x))"""))
      .select("Col1", F.explode("Offsets_tmp").alias("expl"))
      # running sum of lengths (plus the index, for the gaps) is the end offset;
      # subtracting the current length gives the start offset
      .selectExpr(
          "Col1",
          "SUM(expl.x) OVER (PARTITION BY Col1 ORDER BY expl) + expl.i AS right_side",
          "CASE WHEN expl.i = 0 THEN 0 ELSE right_side - expl.x END AS left_side",
          "ARRAY(left_side, right_side) AS arr1"
      )
      # collect the pairs back into one array per original row
      .groupBy("Col1")
      .agg(F.collect_list("arr1").alias("Offset"))
      .select("Col1", F.array_sort("Offset").alias("Offset"))
)
df.show(truncate=False)
+--------------------------+---------------------------------------------------------------------+
|Col1 |Offset |
+--------------------------+---------------------------------------------------------------------+
|[3.0, 4.0, 6.0, 2.0, 1.0] |[[0.0, 3.0], [4.0, 8.0], [9.0, 15.0], [16.0, 18.0], [19.0, 20.0]] |
|[10.0, 5.0, 4.0, 3.0, 2.0]|[[0.0, 10.0], [11.0, 16.0], [17.0, 21.0], [22.0, 25.0], [26.0, 28.0]]|
+--------------------------+---------------------------------------------------------------------+
Tested and works on Spark 3.5.0.
Note: it might not be the prettiest option; you could replace most of the selectExpr with some nicer PySpark code (e.g., pyspark.sql.Window). My personal preference is for SQL when it comes to window functions, though, as it is a bit less verbose in my opinion.
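For reference, here is a rough sketch of what that Window-based variant could look like (my own untested approximation, not part of the tested answer above; df_in is just a fresh name for the same two-row input so the final df is not clobbered):

from pyspark.sql import Window
import pyspark.sql.functions as F

# Same two-row input as above, under a fresh name.
data = [([3.0, 4.0, 6.0, 2.0, 1.0],),
        ([10.0, 5.0, 4.0, 3.0, 2.0],)]
df_in = spark.createDataFrame(data, ["Col1"])

# Mirrors "PARTITION BY Col1 ORDER BY expl" from the SQL version.
w = Window.partitionBy("Col1").orderBy(F.col("expl.i"))

df_window = (
    df_in
    .withColumn("Offsets_tmp", F.expr("transform(Col1, (x, i) -> (i, x))"))
    .select("Col1", F.explode("Offsets_tmp").alias("expl"))
    # running sum of lengths (plus the index, for the gaps) is the end offset;
    # subtracting the current length gives the start offset
    .withColumn("right_side", F.sum("expl.x").over(w) + F.col("expl.i"))
    .withColumn("left_side",
                F.when(F.col("expl.i") == 0, F.lit(0.0))
                 .otherwise(F.col("right_side") - F.col("expl.x")))
    .groupBy("Col1")
    .agg(F.array_sort(F.collect_list(F.array("left_side", "right_side"))).alias("Offset"))
)
df_window.show(truncate=False)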
EDIT: added array_sort. Not sure if it matters to you or not, but collect_list might return the elements in a different order after a shuffle. Since the Offset column is expected to be monotonically ascending, array_sort can be used to have a deterministic output.
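If you would rather not rely on the offsets being ascending at all, a hypothetical variant (my assumption, not from the answer above) is to carry the element index through the selectExpr (e.g., add "expl.i AS i" there), collect it together with each pair, and sort on the index instead:

# Sketch only: df_sel is a hypothetical name for the DataFrame right after the
# selectExpr step, with an extra "expl.i AS i" column added to it.
df_sorted = (
    df_sel
    .groupBy("Col1")
    .agg(F.array_sort(F.collect_list(F.struct("i", "arr1"))).alias("tmp"))
    # sorting the structs orders by i first; then drop it and keep only the pairs
    .select("Col1", F.expr("transform(tmp, s -> s.arr1)").alias("Offset"))
)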