I am trying to derive a new column, "final". The value of the column is derived by referring to the previous value within a group. In my data, colA, colB, colC, and colD form a group, and within this group the only value that changes is colE.
from pyspark.sql import SparkSession

spark = SparkSession.builder \
    .appName("example") \
    .getOrCreate()

# Sample data
data = [
    ("A", "2003-03-01", 1, 11, 1, 10, 0.1),
    ("A", "2003-03-01", 1, 11, 2, 10, 0.2),
    ("A", "2003-03-01", 1, 11, 3, 10, 0.3),
    ("A", "2003-03-01", 1, 11, 4, 10, 0.1),
    ("A", "2003-03-01", 1, 11, 5, 10, 0.2),
]

# Create DataFrame
df = spark.createDataFrame(data, ["colA", "colB", "colC", "colD", "colE", "value", "pred"])

# Show DataFrame
df.show()
# Output data
# Expected output data
output = [
    ("A", "2003-03-01", 1, 11, 1, 10, 0.1, 1.0),
    ("A", "2003-03-01", 1, 11, 2, 10, 0.2, 0.2),
    ("A", "2003-03-01", 1, 11, 3, 10, 0.3, 0.06),
    ("A", "2003-03-01", 1, 11, 4, 10, 0.1, 0.006),
    ("A", "2003-03-01", 1, 11, 5, 10, 0.2, 0.0012),
]

# Create DataFrame
output_df = spark.createDataFrame(output, ["colA", "colB", "colC", "colD", "colE", "value", "pred", "final"])
"Final" column is derived as follows: For the first instance within the group, value is equal to value * pred. The value for ths remaining instance within the group will be: Final (from previous row) * pred.
My current logic is as follows:

from pyspark.sql.window import Window
from pyspark.sql.functions import col, lag, when

window_spec = Window.partitionBy('colA', 'colB', 'colC', 'colD').orderBy('colE')

# This derives the value for the first row within each group
a1 = df.withColumn('final', when(lag('colE').over(window_spec).isNull(), col('pred') * col('value')))

# This is meant to fill the remaining rows from the previous row's final
a1 = a1.withColumn('final', when(col('final').isNotNull(), col('final'))
                   .otherwise(lag(col('final')).over(window_spec) * col('pred')))
However, using the above logic, values are only generated for the first two rows within each group.
# Incorrect output data
incorrect_output = [
    ("A", "2003-03-01", 1, 11, 1, 10, 0.1, 1.0),
    ("A", "2003-03-01", 1, 11, 2, 10, 0.2, 0.2),
    ("A", "2003-03-01", 1, 11, 3, 10, 0.3, None),
    ("A", "2003-03-01", 1, 11, 4, 10, 0.1, None),
    ("A", "2003-03-01", 1, 11, 5, 10, 0.2, None),
]
What am I doing wrong? Can you please assist?
Check this out. The running final is just value * the cumulative product of pred within the group, so you can collect the preds seen so far with collect_list over a running window and multiply them together with aggregate:
import pyspark.sql.functions as f
from pyspark.sql.window import Window
from pyspark.sql.types import FloatType

w = Window.partitionBy("colA", "colB", "colC", "colD").orderBy("colE")

df = (
    df
    .withColumn("first_value", f.first("value").over(w))
    .withColumn("preds", f.collect_list("pred").over(w))
    .select(
        df['*'],
        (f.col('first_value') * f.expr('aggregate(preds, cast(1 as DOUBLE), (acc, x) -> acc * x)')).cast(FloatType()).alias('Final')
    )
)
And the output is:
+----+----------+----+----+----+-----+----+------+
|colA| colB|colC|colD|colE|value|pred| Final|
+----+----------+----+----+----+-----+----+------+
| A|2003-03-01| 1| 11| 1| 10| 0.1| 1.0|
| A|2003-03-01| 1| 11| 2| 10| 0.2| 0.2|
| A|2003-03-01| 1| 11| 3| 10| 0.3| 0.06|
| A|2003-03-01| 1| 11| 4| 10| 0.1| 0.006|
| A|2003-03-01| 1| 11| 5| 10| 0.2|0.0012|
+----+----------+----+----+----+-----+----+------+
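If you are on Spark 3.1+, you may be able to skip the collect_list/aggregate step and use f.product as a windowed aggregate instead. This is an untested sketch along the same lines, applied to the original input DataFrame with the same running window:

import pyspark.sql.functions as f
from pyspark.sql.window import Window

# Running product of pred from the first row of the group up to the current row
w = Window.partitionBy("colA", "colB", "colC", "colD").orderBy("colE")

result = df.withColumn(
    "Final",
    (f.first("value").over(w) * f.product("pred").over(w)).cast("float")
)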