apache-spark, pyspark, apache-spark-sql, time-series, interpolation

How to interpolate time series based on time gap between non null values in PySpark


I would like to interpolate time series data. The challenge is to interpolate only if the time interval between the existing (non-null) values is not greater than a specified limit.

Input data

from pyspark.sql import SparkSession
spark = SparkSession.builder.config("spark.driver.memory", "60g").getOrCreate()

df = spark.createDataFrame([{'timestamp': 1642205833225, 'value': 58.00},
                            {'timestamp': 1642205888654, 'value': float('nan')},
                            {'timestamp': 1642205899657, 'value': float('nan')},
                            {'timestamp': 1642205892970, 'value': 55.00},
                            {'timestamp': 1642206338180, 'value': float('nan')},
                            {'timestamp': 1642206353652, 'value': 56.45},
                            {'timestamp': 1642206853451, 'value': float('nan')},
                            {'timestamp': 1642207353652, 'value': 80.45}
                            ])
df.show()

+-------------+-----+
|    timestamp|value|
+-------------+-----+
|1642205833225| 58.0|
|1642205888654|  NaN|
|1642205899657|  NaN|
|1642205892970| 55.0|
|1642206338180|  NaN|
|1642206353652|56.45|
|1642206853451|  NaN|
|1642207353652|80.45|
+-------------+-----+

First I want to calculate the time gap to the next existing (non-null) value, i.e. next non-null timestamp minus current timestamp (a worked example follows the table below).

+-------------+-----+---------------+
|    timestamp|value|timegap_to_next|
+-------------+-----+---------------+
|1642205833225| 58.0|          59745|
|1642205888654|  NaN|            NaN|
|1642205899657|  NaN|            NaN|
|1642205892970| 55.0|         460682|
|1642206338180|  NaN|            NaN|
|1642206353652|56.45|        1000000|
|1642206853451|  NaN|            NaN|
|1642207353652|80.45|            NaN|
+-------------+-----+---------------+
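
For example, for the first row the next existing value is 55.0 at timestamp 1642205892970, so the gap is simply the timestamp difference:

1642205892970 - 1642205833225  # = 59745, the timegap_to_next of the first row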

Based on the calculated time gap, the interpolation should only be done where the gap is below a threshold, in this case 500000. That is why the NaN at timestamp 1642206853451 stays NaN in the expected output below: the gap around it exceeds the threshold.

Final Output:

+-------------+-----+---------------+
|    timestamp|value|timegap_to_next|
+-------------+-----+---------------+
|1642205833225| 58.0|          59745|
|1642205888654| 57.0|            NaN|
|1642205899657| 56.0|            NaN|
|1642205892970| 55.0|         460682|
|1642206338180|55.75|            NaN|
|1642206353652|56.45|        1000000|
|1642206338180|  NaN|            NaN|
|1642207353652|80.45|            NaN|
+-------------+-----+---------------+

Can anybody help me with this special case? That would be very nice!


Solution

  • Having this input dataframe (note that the fourth timestamp differs slightly from the question's so that the rows are strictly ordered by timestamp):

    df = spark.createDataFrame([
        (1642205833225, 58.00), (1642205888654, float('nan')),
        (1642205899657, float('nan')), (1642205899970, 55.00),
        (1642206338180, float('nan')), (1642206353652, 56.45),
        (1642206853451, float('nan')), (1642207353652, 80.45)
    ], ["timestamp", "value"])
    
    # replace NaN values by nulls so that ignorenulls-based window functions and coalesce treat them as missing
    df = df.replace(float("nan"), None, ["value"])
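
    This replacement matters because Spark treats NaN as an ordinary double value, not as a null, so first/last with ignorenulls=True and coalesce would not skip the NaN rows. A minimal check (not part of the solution):

    from pyspark.sql import functions as F
    
    spark.range(1).select(
        F.isnull(F.lit(float("nan"))).alias("nan_is_null"),  # false: NaN is not null
        F.isnan(F.lit(float("nan"))).alias("nan_is_nan"),    # true
    ).show()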
    

    You can use window functions (first, last with ignorenulls=True) to get the next and previous non-null values and their row numbers for each row, and calculate the time gap to the next non-null value like this:

    from pyspark.sql import functions as F, Window
    
    # w1: all rows after the current row; w2: all rows before the current row
    w1 = Window.orderBy("timestamp").rowsBetween(1, Window.unboundedFollowing)
    w2 = Window.orderBy("timestamp").rowsBetween(Window.unboundedPreceding, -1)
    
    df = (
        df.withColumn("rn", F.row_number().over(Window.orderBy("timestamp")))
        # next non-null value and its row number
        .withColumn("next_val", F.first("value", ignorenulls=True).over(w1))
        .withColumn("next_rn", F.first(F.when(F.col("value").isNotNull(), F.col("rn")), ignorenulls=True).over(w1))
        # previous non-null value and its row number
        .withColumn("prev_val", F.last("value", ignorenulls=True).over(w2))
        .withColumn("prev_rn", F.last(F.when(F.col("value").isNotNull(), F.col("rn")), ignorenulls=True).over(w2))
        # for non-null rows only: time gap to the next non-null value
        .withColumn("timegap_to_next", F.when(F.col("value").isNotNull(), F.min(F.when(F.col("value").isNotNull(), F.col("timestamp"))).over(w1) - F.col("timestamp")))
    )
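
    Note that none of these windows uses partitionBy, so Spark will move the whole dataset to a single partition to evaluate them (it logs a warning about this). That is fine for small data; if you had several independent series identified by a key column (say a hypothetical device_id), you would partition each window by it, for example:

    # assumption: "device_id" is a per-series key column that does not exist in the original data
    w1 = Window.partitionBy("device_id").orderBy("timestamp").rowsBetween(1, Window.unboundedFollowing)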
    

    Now, you can do the linear interpolation of the value column depending on your threshold, using a when expression. Here the interpolation is linear in the row number (rn) between the previous and next non-null values, which reproduces the values in the question's expected output:

    w3 = Window.orderBy("timestamp").rowsBetween(Window.unboundedPreceding, Window.currentRow)
    
    df = df.withColumn(
        "value",
        F.coalesce(
            "value",
            F.when(
                # interpolate only when the gap around this row is below the threshold;
                # for a null row, the last non-null timegap_to_next is the gap that spans it
                F.last("timegap_to_next", ignorenulls=True).over(w3) < 500000,
                F.col("prev_val")
                + (F.col("next_val") - F.col("prev_val"))
                / (F.col("next_rn") - F.col("prev_rn"))
                * (F.col("rn") - F.col("prev_rn"))
            )
        )
    ).select("timestamp", "value", "timegap_to_next")
    
    df.show()
    
    #+-------------+------+---------------+
    #|    timestamp| value|timegap_to_next|
    #+-------------+------+---------------+
    #|1642205833225|  58.0|          66745|
    #|1642205888654|  57.0|           null|
    #|1642205899657|  56.0|           null|
    #|1642205899970|  55.0|         453682|
    #|1642206338180|55.725|           null|
    #|1642206353652| 56.45|        1000000|
    #|1642206853451|  null|           null|
    #|1642207353652| 80.45|           null|
    #+-------------+------+---------------+
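
    If you need this in several places, the whole pipeline can be wrapped in a small reusable helper. This is only a sketch under the same assumptions as above (columns named timestamp and value, a single series, NaNs already replaced by nulls); the function name interpolate_with_gap_limit is made up here:

    from pyspark.sql import DataFrame, functions as F, Window
    
    def interpolate_with_gap_limit(sdf: DataFrame, threshold: int) -> DataFrame:
        """Linearly interpolate null 'value' rows, but only inside gaps between
        non-null values that are shorter than `threshold` (same unit as timestamp)."""
        w_next = Window.orderBy("timestamp").rowsBetween(1, Window.unboundedFollowing)
        w_prev = Window.orderBy("timestamp").rowsBetween(Window.unboundedPreceding, -1)
        w_run = Window.orderBy("timestamp").rowsBetween(Window.unboundedPreceding, Window.currentRow)
        non_null_rn = F.when(F.col("value").isNotNull(), F.col("rn"))
        non_null_ts = F.when(F.col("value").isNotNull(), F.col("timestamp"))
        sdf = (
            sdf.withColumn("rn", F.row_number().over(Window.orderBy("timestamp")))
            .withColumn("next_val", F.first("value", ignorenulls=True).over(w_next))
            .withColumn("next_rn", F.first(non_null_rn, ignorenulls=True).over(w_next))
            .withColumn("prev_val", F.last("value", ignorenulls=True).over(w_prev))
            .withColumn("prev_rn", F.last(non_null_rn, ignorenulls=True).over(w_prev))
            .withColumn("timegap_to_next",
                        F.when(F.col("value").isNotNull(),
                               F.min(non_null_ts).over(w_next) - F.col("timestamp")))
        )
        interpolated = (
            F.col("prev_val")
            + (F.col("next_val") - F.col("prev_val"))
            / (F.col("next_rn") - F.col("prev_rn"))
            * (F.col("rn") - F.col("prev_rn"))
        )
        return sdf.withColumn(
            "value",
            F.coalesce("value", F.when(F.last("timegap_to_next", ignorenulls=True).over(w_run) < threshold, interpolated)),
        ).select("timestamp", "value", "timegap_to_next")
    
    # usage, on a dataframe where NaN has already been replaced by null:
    # result = interpolate_with_gap_limit(df, threshold=500000)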