python, pyspark

How to split a pyspark dataframe taking a portion of data for each different id


I'm working with a PySpark DataFrame (in Python) containing time series data. The data has a structure like this:

event_time  variable value   step   ID 
1456942945  var_a    123.4    1      id_1
1456931076  var_b    857.01   1      id_1
1456932268  var_b    871.74   1      id_1
1456940055  var_b    992.3    2      id_1
1456932781  var_c    861.3    2      id_1
1456937186  var_c    959.6    3      id_1
1456934746  var_d    0.12     4      id_1

1456942945  var_a    123.4    1      id_2
1456931076  var_b    847.01   1      id_2
1456932268  var_b    871.74   1      id_2
1456940055  var_b    932.3    2      id_2
1456932781  var_c    821.3    3      id_2
1456937186  var_c    969.6    4      id_2
1456934746  var_d    0.12     4      id_2
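
For reference, a minimal sketch that builds a DataFrame with this shape (column names as in the table above, using a subset of the rows shown), so the snippets below can be tried out:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# A few of the rows from the table above
rows = [
    (1456942945, "var_a", 123.4, 1, "id_1"),
    (1456931076, "var_b", 857.01, 1, "id_1"),
    (1456940055, "var_b", 992.3, 2, "id_1"),
    (1456937186, "var_c", 959.6, 3, "id_1"),
    (1456934746, "var_d", 0.12, 4, "id_1"),
    (1456942945, "var_a", 123.4, 1, "id_2"),
    (1456932781, "var_c", 821.3, 3, "id_2"),
    (1456934746, "var_d", 0.12, 4, "id_2"),
]
DF = spark.createDataFrame(rows, ["event_time", "variable", "value", "step", "ID"])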

For each id I have each variable's value at a specific "step".

I need to subset this DataFrame as follows: for each id, take all the rows corresponding to steps 1, 2 and 3, plus a portion of the step 4 data starting from the first event_time value of step 4, say the first 25%. This partitioning is to be done with respect to event_time.

I'm able to do it for a single id, after having subset the DF to that id:

from pyspark.sql.functions import col

# Single step partitioning: threshold at the 25th percentile of event_time
threshold_value = DF.selectExpr("percentile_approx(event_time, 0.25) as threshold") \
    .collect()[0]["threshold"]

partitioned_df = DF.filter(col("event_time") <= threshold_value)

# First 3 steps
first_3_steps_df = DF.filter(col("step").isin([1, 2, 3]))

Then I would union partitioned_df and first_3_steps_df to obtain the desired output for one specific id, as shown below. I'm stuck on applying this kind of partitioning for every id in DF without actually iterating the process for each id separately.
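
A sketch of that single-id union, assuming the two DataFrames from the snippet above:

# Combine the step 1-3 rows with the partitioned step 4 rows for this id
single_id_result = first_3_steps_df.unionByName(partitioned_df)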

I'm also able to do it in pandas, but the DF is huge and I really need to stick to PySpark, so no pandas answers, please.


Solution

  • Group the data by ID and use percentile_approx as the aggregation function to calculate the threshold for step=4. Then build a where clause from these values to filter the data:

    from pyspark.sql import functions as F
    
    df = ...
    
    # Per-ID 25th percentile of event_time within step 4
    threshold = df.where('step = 4') \
        .groupBy('ID') \
        .agg(F.percentile_approx('event_time', 0.25)) \
        .collect()
    
    threshold = [(r[0], r[1]) for r in threshold]
    
    # Keep steps 1-3 unconditionally
    whereStmt = 'step=1 or step=2 or step=3'
    
    # For step 4, keep only rows up to each ID's threshold
    # (ID is a string such as 'id_1', so its value must be quoted in the SQL expression)
    for r in threshold:
        whereStmt = whereStmt + f" or (step=4 and ID='{r[0]}' and event_time<={r[1]})"
    
    df.where(F.expr(whereStmt)).show()
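
    A note on the design choice: this collects the per-ID thresholds to the driver and builds one long predicate, which works but grows with the number of IDs. An alternative sketch (not the answer's approach) keeps the thresholds as a DataFrame and joins them back before filtering:

    from pyspark.sql import functions as F

    # Per-ID threshold as a DataFrame instead of a collected list
    thresholds = df.where('step = 4') \
        .groupBy('ID') \
        .agg(F.percentile_approx('event_time', 0.25).alias('threshold'))

    # Keep steps 1-3, plus step 4 rows up to each ID's threshold
    result = df.join(thresholds, on='ID', how='left') \
        .where(F.col('step').isin(1, 2, 3) |
               ((F.col('step') == 4) & (F.col('event_time') <= F.col('threshold')))) \
        .drop('threshold')

    result.show()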