
Derive value based on previous row

I am trying to derive a new column, "final", whose value depends on the previous row within a group. In my data, colA, colB, colC and colD form a group, and within each group colE is what changes from row to row (it determines the row order).

# Create SparkSession

from pyspark.sql import SparkSession

spark = SparkSession.builder \
    .appName("example") \
    .getOrCreate()

# Sample data
data = [
    ("A", "2003-03-01", 1, 11, 1, 10, 0.1),
    ("A", "2003-03-01", 1, 11, 2, 10, 0.2),
    ("A", "2003-03-01", 1, 11, 3, 10, 0.3),
    ("A", "2003-03-01", 1, 11, 4, 10, 0.1),
    ("A", "2003-03-01", 1, 11, 5, 10, 0.2),
]

# Create DataFrame
df = spark.createDataFrame(data, ["colA", "colB", "colC", "colD", "colE", "value", "pred"])

# Show DataFrame
df.show()

# Expected output
output = [
    ("A", "2003-03-01", 1, 11, 1, 10, 0.1, 1.0),
    ("A", "2003-03-01", 1, 11, 2, 10, 0.2, 0.2),
    ("A", "2003-03-01", 1, 11, 3, 10, 0.3, 0.06),
    ("A", "2003-03-01", 1, 11, 4, 10, 0.1, 0.006),
    ("A", "2003-03-01", 1, 11, 5, 10, 0.2, 0.0012),
]

# Create DataFrame
output_df = spark.createDataFrame(output, ["colA", "colB", "colC", "colD", "colE", "value", "pred", "final"])

The "final" column is derived as follows: for the first row in a group, final = value * pred. For every remaining row in the group, final = final (from the previous row) * pred.
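
To pin down the recurrence, here is a minimal plain-Python sketch (no Spark involved) that applies it group by group to the sample rows; the helper name derive_final is hypothetical. Unrolling the recurrence shows that final for the n-th row of a group is simply the first row's value times the running product of pred up to and including that row, which is what makes it computable without true row-by-row recursion.

```python
from itertools import groupby
from operator import itemgetter

# Sample rows from the question: (colA, colB, colC, colD, colE, value, pred)
data = [
    ("A", "2003-03-01", 1, 11, 1, 10, 0.1),
    ("A", "2003-03-01", 1, 11, 2, 10, 0.2),
    ("A", "2003-03-01", 1, 11, 3, 10, 0.3),
    ("A", "2003-03-01", 1, 11, 4, 10, 0.1),
    ("A", "2003-03-01", 1, 11, 5, 10, 0.2),
]


def derive_final(rows):
    """Hypothetical helper: append 'final' to each row using the recurrence.

    First row of a group:  final = value * pred
    Every later row:       final = previous final * pred
    """
    group_key = itemgetter(0, 1, 2, 3)  # colA..colD identify the group
    out = []
    # groupby needs rows sorted by the group key; colE orders rows inside a group
    for _, group in groupby(sorted(rows, key=itemgetter(0, 1, 2, 3, 4)), key=group_key):
        prev = None
        for row in group:
            value, pred = row[5], row[6]
            final = value * pred if prev is None else prev * pred
            out.append(row + (final,))
            prev = final
    return out


for row in derive_final(data):
    print(row[4], round(row[7], 6))  # colE and final
```

The loop is only a reference for checking results on small data; on a real DataFrame the same numbers need a window-based formulation.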

My current logic is as follows:

from pyspark.sql.functions import col, lag, when
from pyspark.sql.window import Window

window_spec = Window.partitionBy('colA', 'colB', 'colC', 'colD').orderBy('colE')
# This will derive value for first row within each group
a1 = df.withColumn('final', when(lag('colE').over(window_spec).isNull(), col('pred') * col('value')))
# Second pass: intended to fill the remaining rows from the previous row's 'final'
a1 = a1.withColumn('final', when(col('final').isNotNull(), col('final'))
                   .otherwise(lag(col('final')).over(window_spec) * col('pred')))

However, with the above logic only the first two rows in each group get a value; the remaining rows are null:

# Actual (incorrect) output
incorrect_output = [
    ("A", "2003-03-01", 1, 11, 1, 10, 0.1, 1.0),
    ("A", "2003-03-01", 1, 11, 2, 10, 0.2, 0.2),
    ("A", "2003-03-01", 1, 11, 3, 10, 0.3, None),
    ("A", "2003-03-01", 1, 11, 4, 10, 0.1, None),
    ("A", "2003-03-01", 1, 11, 5, 10, 0.2, None),
]

What am I doing wrong? Can you please assist?
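
A likely explanation: lag(col('final')) reads the final column as it exists in the DataFrame that the second withColumn is applied to, not the values being computed in that same pass. Window functions are evaluated once over their input, not recursively. After the first step only row 1 has a non-null final, so the second step can fill row 2 (its previous row is row 1), but row 3 still sees the null on row 2. The plain-Python sketch below imitates the two steps on a single group; the lag function in it is a hypothetical stand-in for Spark's lag, not the real API.

```python
# One group from the question, already ordered by colE: (colE, value, pred)
rows = [(1, 10, 0.1), (2, 10, 0.2), (3, 10, 0.3), (4, 10, 0.1), (5, 10, 0.2)]


def lag(column):
    """Stand-in for Spark's lag(): previous row's value, None for the first row."""
    return [None] + column[:-1]


# Step 1: only the first row of the group gets value * pred
step1 = [value * pred if i == 0 else None for i, (_, value, pred) in enumerate(rows)]

# Step 2: every lag() value is read from step1, i.e. the column *before* this step
prev = lag(step1)
step2 = [
    cur if cur is not None else (p * pred if p is not None else None)
    for cur, p, (_, _, pred) in zip(step1, prev, rows)
]

print(step2)  # [1.0, 0.2, None, None, None], matching incorrect_output
```

Each additional withColumn/lag pass would fill exactly one more row, so a non-recursive formulation (a running product of pred within the group) is needed instead.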


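  • Check this out: each row's Final is the group's first value multiplied by all preds up to and including that row, so you can collect those preds with an ordered window and multiply them with the aggregate higher-order function: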

    import pyspark.sql.functions as f
    from pyspark.sql.types import FloatType
    from pyspark.sql.window import Window
    w = Window.partitionBy("colA", "colB", "colC", "colD").orderBy("colE")  # default frame: group start to current row
    df = (
        df.withColumn("first_value", f.first("value").over(w))  # value of the group's first row
        .withColumn("preds", f.collect_list("pred").over(w))  # preds seen so far in the group
        .select("colA", "colB", "colC", "colD", "colE", "value", "pred",
                (f.col("first_value") * f.expr("aggregate(preds, cast(1 as DOUBLE), (acc, x) -> acc * x)")).cast(FloatType()).alias("Final"))
    )

    And the output is:

    |colA|      colB|colC|colD|colE|value|pred| Final|
    |   A|2003-03-01|   1|  11|   1|   10| 0.1|   1.0|
    |   A|2003-03-01|   1|  11|   2|   10| 0.2|   0.2|
    |   A|2003-03-01|   1|  11|   3|   10| 0.3|  0.06|
    |   A|2003-03-01|   1|  11|   4|   10| 0.1| 0.006|
    |   A|2003-03-01|   1|  11|   5|   10| 0.2|0.0012|
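
Because Final is the group's first value times a running product of pred, another option that avoids building an array for every row is to turn the product into a running sum of logarithms: Final = first(value) * exp(sum(log(pred))). In Spark this would roughly correspond to f.first("value").over(w) * f.exp(f.sum(f.log("pred")).over(w)) with the same ordered window w; treat that expression as an untested sketch. It assumes every pred is strictly positive (the logarithm is undefined for zero or negative values), and the log/exp round trip adds a little floating-point error. The plain-Python check below compares both formulations on the sample preds.

```python
import math
from itertools import accumulate
from operator import mul

# Sample preds for one group, in colE order, and the group's first value
preds = [0.1, 0.2, 0.3, 0.1, 0.2]
first_value = 10

# Direct running product, like aggregate(preds, 1.0, (acc, x) -> acc * x) per row
direct = [first_value * p for p in accumulate(preds, mul)]

# Log-sum variant: running sum of log(pred), then exp() back (needs every pred > 0)
log_sum = [first_value * math.exp(s) for s in accumulate(math.log(p) for p in preds)]

for d, l in zip(direct, log_sum):
    print(f"{d:.6g} {l:.6g}")
```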