
Sample Pyspark DataFrame by date with number of days from last entry between each sample


Given a DataFrame:

import datetime
from pyspark.sql import Row

dataframe_rows = [
    Row(id="A", date=datetime.datetime(2015, 1, 18)),
    Row(id="A", date=datetime.datetime(2015, 2, 21)),
    Row(id="A", date=datetime.datetime(2015, 2, 22)),
    Row(id="A", date=datetime.datetime(2015, 6, 30)),
    Row(id="A", date=datetime.datetime(2017, 12, 31)),
    Row(id="B", date=datetime.datetime(2019, 1, 18)),
    Row(id="B", date=datetime.datetime(2019, 1, 21)),
    Row(id="B", date=datetime.datetime(2019, 2, 22)),
    Row(id="B", date=datetime.datetime(2019, 2, 28)),
    Row(id="B", date=datetime.datetime(2019, 12, 13)),
]

df_example = spark.createDataFrame(dataframe_rows).orderBy(["id", "date"], ascending=[1, 1])

So

df_example.show()

yields

+---+-------------------+
| id|               date|
+---+-------------------+
|  A|2015-01-18 00:00:00|
|  A|2015-02-21 00:00:00|
|  A|2015-02-22 00:00:00|
|  A|2015-06-30 00:00:00|
|  A|2017-12-31 00:00:00|
|  B|2019-01-18 00:00:00|
|  B|2019-01-21 00:00:00|
|  B|2019-02-22 00:00:00|
|  B|2019-02-28 00:00:00|
|  B|2019-12-13 00:00:00|
+---+-------------------+

I want a function that samples rows from this DataFrame so that, per id, there are at least a specified number of days between consecutive samples, and so that the last date per id in the sampled DataFrame is the same as the last date per id in the original DataFrame.

For example, using 14 days between each sample:

+---+-------------------+
| id|               date|
+---+-------------------+
|  A|2015-01-18 00:00:00|
|  A|2015-02-22 00:00:00|
|  A|2015-06-30 00:00:00|
|  A|2017-12-31 00:00:00|
|  B|2019-01-18 00:00:00|
|  B|2019-02-28 00:00:00|
|  B|2019-12-13 00:00:00|
+---+-------------------+

Note the last date for each id is the same as it was in the original DataFrame.

Edit: The solution below works with the original DataFrame I provided, but if I change it to

import datetime
from pyspark.sql import Row

dataframe_rows = [
    Row(id="A", date=datetime.datetime(2000, 11, 12)),
    Row(id="A", date=datetime.datetime(2000, 12, 13)),
    Row(id="A", date=datetime.datetime(2000, 12, 29)),
    Row(id="A", date=datetime.datetime(2000, 12, 30)),
    Row(id="A", date=datetime.datetime(2000, 12, 31)),
    Row(id="B", date=datetime.datetime(2002, 2, 18)),
    Row(id="B", date=datetime.datetime(2002, 2, 21)),
    Row(id="B", date=datetime.datetime(2002, 2, 27)),
    Row(id="B", date=datetime.datetime(2002, 2, 28)),
    Row(id="B", date=datetime.datetime(2002, 12, 13)),
]

df_example = spark.createDataFrame(dataframe_rows).orderBy(["id", "date"], ascending=[1, 1])
df_example.show()

yielding

+---+-------------------+
| id|               date|
+---+-------------------+
|  A|2000-11-12 00:00:00|
|  A|2000-12-13 00:00:00|
|  A|2000-12-29 00:00:00|
|  A|2000-12-30 00:00:00|
|  A|2000-12-31 00:00:00|
|  B|2002-02-18 00:00:00|
|  B|2002-02-21 00:00:00|
|  B|2002-02-27 00:00:00|
|  B|2002-02-28 00:00:00|
|  B|2002-12-13 00:00:00|
+---+-------------------+

and apply the code, I get

+---+----------+
| id|      date|
+---+----------+
|  A|2000-11-12|
|  A|2000-12-13|
|  A|2000-12-31|
|  B|2002-02-27|
|  B|2002-02-28|
|  B|2002-12-13|
+---+----------+

I am not sure why both February dates survive. I expected to see

+---+----------+
| id|      date|
+---+----------+
|  A|2000-11-12|
|  A|2000-12-13|
|  A|2000-12-31|
|  B|2002-02-28|
|  B|2002-12-13|
+---+----------+

Any ideas?


Solution

  • There's no direct timestamp resampling function in PySpark. However, I found a helper function in this blog post that solves the issue. The function converts the timestamp to a Unix timestamp and aggregates the data according to the resampling interval you specify.

    A Unix timestamp represents a point in time as the number of seconds elapsed since 1 January 1970 00:00:00 UTC. For example, 1 January 1970 00:00:00 = 0 seconds, 1 January 1970 01:00:00 = 3600 seconds, and 2 January 1970 00:00:00 = 86400 seconds.
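
    As a quick illustration of that arithmetic (a minimal pure-Python check, independent of Spark):

    import datetime

    # Seconds elapsed since the Unix epoch (1 January 1970 00:00:00 UTC).
    epoch = datetime.datetime(1970, 1, 1)
    print((datetime.datetime(1970, 1, 1, 1) - epoch).total_seconds())  # 3600.0
    print((datetime.datetime(1970, 1, 2) - epoch).total_seconds())     # 86400.0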

    After creating a resampled column with that function, you can continue with a .groupBy() and an aggregation using F.last() from PySpark.

    Resampling Function

    Edit (May 19th, 2021): Added offset

    import pyspark.sql.functions as F
    
    def resample(column, agg_interval=900, offset=0, time_format='yyyy-MM-dd HH:mm:ss'):
        if isinstance(column, str):
            column = F.col(column)
    
        # Convert the timestamp to Unix timestamp format.
        # Unix timestamp = number of seconds since 00:00:00 UTC, 1 January 1970.
        col_ut = F.unix_timestamp(column, format=time_format)
    
        # Divide the time into discrete intervals, by rounding.
        col_ut_agg = F.floor((col_ut + offset) / agg_interval) * agg_interval
    
        # Convert back to and return a human-readable timestamp.
        return F.from_unixtime(col_ut_agg)
    

    Resampling without offset

    # 14 days = 60 seconds * 60 minutes * 24 hours * 14 days
    df_example\
        .withColumn('date_resampled', resample(df_example.date, 60*60*24*14))\
        .groupBy('date_resampled')\
        .agg(F.last('id').alias('id'), F.last('date').alias('date'))\
        .orderBy(['id', 'date'])\
        .show()
    

    Output:

    +-------------------+---+-------------------+
    |     date_resampled| id|               date|
    +-------------------+---+-------------------+
    |2000-11-09 00:00:00|  A|2000-11-12 00:00:00|
    |2000-12-07 00:00:00|  A|2000-12-13 00:00:00|
    |2000-12-21 00:00:00|  A|2000-12-31 00:00:00|
    |2002-02-14 00:00:00|  B|2002-02-27 00:00:00|
    |2002-02-28 00:00:00|  B|2002-02-28 00:00:00|
    |2002-12-05 00:00:00|  B|2002-12-13 00:00:00|
    +-------------------+---+-------------------+
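
    This output also explains the edit in the question: the 14-day buckets are anchored at the Unix epoch, and 2002-02-27 and 2002-02-28 fall into different buckets (labelled 2002-02-14 and 2002-02-28 above), so each survives as the last row of its own bucket. Here is a minimal pure-Python sketch of the same bucket arithmetic, assuming the timestamps are interpreted as UTC:

    import datetime

    epoch = datetime.datetime(1970, 1, 1)
    interval = 60 * 60 * 24 * 14  # 14 days in seconds

    def bucket(dt, offset=0):
        # Same idea as resample(): floor the (shifted) timestamp to a multiple of the interval.
        seconds = (dt - epoch).total_seconds()
        return epoch + datetime.timedelta(seconds=(seconds + offset) // interval * interval)

    print(bucket(datetime.datetime(2002, 2, 27)))  # 2002-02-14 00:00:00
    print(bucket(datetime.datetime(2002, 2, 28)))  # 2002-02-28 00:00:00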
    

    Resampling with offset

    Without an offset, the resampling buckets are aligned to 1 January 1970 00:00:00 GMT (the Unix epoch). The offset shifts each timestamp by the given number of seconds before it is floored, which moves the bucket boundaries and can push a date into a neighbouring bucket.

    # 14 days = 60 seconds * 60 minutes * 24 hours * 14 days
    # offset = 10 days
    df_example\
        .withColumn('date_resampled', resample(df_example.date, 60*60*24*14, 60*60*24*10))\
        .groupBy('date_resampled')\
        .agg(F.last('id').alias('id'), F.last('date').alias('date'))\
        .orderBy(['id', 'date'])\
        .show()
    

    Output:

    +-------------------+---+-------------------+
    |     date_resampled| id|               date|
    +-------------------+---+-------------------+
    |2000-11-09 00:00:00|  A|2000-11-12 00:00:00|
    |2000-12-21 00:00:00|  A|2000-12-13 00:00:00|
    |2001-01-04 00:00:00|  A|2000-12-31 00:00:00|
    |2002-02-28 00:00:00|  B|2002-02-28 00:00:00|
    |2002-12-19 00:00:00|  B|2002-12-13 00:00:00|
    +-------------------+---+-------------------+
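
    The same arithmetic shows why the 10-day offset collapses the two February dates: shifted by 10 days, both 2002-02-27 and 2002-02-28 land in the bucket labelled 2002-02-28, so only one of them survives (2002-02-28 in the output above). A minimal self-contained check (pure Python, assuming UTC interpretation):

    import datetime

    epoch = datetime.datetime(1970, 1, 1)
    interval = 60 * 60 * 24 * 14  # 14 days in seconds
    offset = 60 * 60 * 24 * 10    # 10 days in seconds

    for d in (datetime.datetime(2002, 2, 27), datetime.datetime(2002, 2, 28)):
        seconds = (d - epoch).total_seconds()
        label = epoch + datetime.timedelta(seconds=(seconds + offset) // interval * interval)
        print(d.date(), '->', label)  # both map to 2002-02-28 00:00:00

    To recover the asker's two-column shape, the helper column can simply be dropped at the end, e.g. with .select('id', 'date'); and if different ids could ever share a bucket, grouping by both 'id' and 'date_resampled' would keep F.last() from mixing ids.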