How to apply conditions to groupby dataframe in PySpark

I have a dataframe like this:

ID   Transaction_time     Status     final_time
1     1981-01-12           hit    
1     1981-01-13           hit        
1     1981-01-14           good     1981-01-15   
1     1981-01-15           OK       1981-01-16
2     1981-01-06           good     1981-01-17
3     1981-01-07           hit      1981-01-16
4     1981-01-06           hit      
4     1981-01-07           good      
4     1981-01-08           good     1981-01-10

I would like to keep an ID only if both of these hold:

  • its Status values include "hit" and also "good" or "OK"
  • final_time is not empty at the last Transaction_time

Then, I would like to extract:

  • id - the ID
  • status - the Status at the last Transaction_time
  • start_time - the Transaction_time when Status changes from "hit" to "good"
  • finish_time - the final_time at the last Transaction_time

For the above example, it would be:

id    status       start_time       finish_time
1     OK           1981-01-14       1981-01-16
4     good         1981-01-07       1981-01-10

How can I do this in PySpark?
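
Before moving to Spark, the rules above can be pinned down in plain Python. This is only a reference sketch of the intended semantics, not a PySpark solution. The function name `summarize` and the tuple layout are assumptions made for illustration, and "Status changes from hit to good" is read here as the first row with Status "good" or "OK" that directly follows a "hit" row.

```python
from itertools import groupby
from operator import itemgetter

# Rows mirror the question's table: (ID, Transaction_time, Status, final_time).
rows = [
    (1, "1981-01-12", "hit", None),
    (1, "1981-01-13", "hit", None),
    (1, "1981-01-14", "good", "1981-01-15"),
    (1, "1981-01-15", "OK", "1981-01-16"),
    (2, "1981-01-06", "good", "1981-01-17"),
    (3, "1981-01-07", "hit", "1981-01-16"),
    (4, "1981-01-06", "hit", None),
    (4, "1981-01-07", "good", None),
    (4, "1981-01-08", "good", "1981-01-10"),
]


def summarize(rows):
    """Return (id, status, start_time, finish_time) for each qualifying ID."""
    result = []
    # ISO-formatted dates sort chronologically as plain strings.
    ordered = sorted(rows, key=itemgetter(0, 1))
    for key, group in groupby(ordered, key=itemgetter(0)):
        group = list(group)
        last = group[-1]
        # Rule 2: final_time must be set on the last transaction.
        if last[3] is None:
            continue
        # Rule 1 plus start_time: first "hit" row directly followed by "good"/"OK".
        start = next(
            (cur[1] for prev, cur in zip(group, group[1:])
             if prev[2] == "hit" and cur[2] in ("good", "OK")),
            None,
        )
        if start is None:
            continue
        result.append((key, last[2], start, last[3]))
    return result


for record in summarize(rows):
    print(record)
```

On the sample data this keeps IDs 1 and 4, matching the expected table above.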


  • I mostly used window functions instead of groupby:

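    # Assumes F, Window and df1 from the "Original dataset" setup below.
    # Descending order makes lag() return the chronologically next row.
    w1 = Window.partitionBy('ID').orderBy(F.col('Transaction_time').desc())
    w2 = Window.partitionBy('ID').orderBy(F.col('final_time').desc())
    df2 = df1.withColumn('next_st', F.lag('Status', 1).over(w1)) \
             .withColumn('next_tt', F.lag('Transaction_time', 1).over(w1)) \
             .withColumn('max_tt', F.max('Transaction_time').over(w1)) \
             .withColumn('max_ft', F.max('final_time').over(w2))
    # Keep only IDs whose last transaction has a non-null final_time.
    df3 = df2.join(df2.filter((F.col('Transaction_time') == F.col('max_tt')) & F.col('final_time').isNotNull()), 'ID', 'leftsemi')
    # Keep the "hit" row that is directly followed by "good" or "OK".
    df4 = df3.filter((F.col('Status') == 'hit') & F.col('next_st').isin(['good', 'OK']))
    # Look up the Status of the row that carries the latest final_time.
    df5 = (
        df4.alias('df4')
        .join(df1.alias('df1'), (F.col('df1.ID') == F.col('df4.ID')) & (F.col('df1.final_time') == F.col('df4.max_ft')))
        .select(F.col('df4.ID').alias('id'), F.col('df1.Status').alias('status'),
                F.col('next_tt').alias('start_time'), F.col('max_ft').alias('finish_time'))
    )
    df5.show()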
    #  +---+------+----------+-----------+
    #  | id|status|start_time|finish_time|
    #  +---+------+----------+-----------+
    #  |  4|  good|1981-01-07| 1981-01-10|
    #  |  1|    OK|1981-01-14| 1981-01-16|
    #  +---+------+----------+-----------+


    from pyspark.sql import functions as F, Window

    Original dataset:

    data = [
        (1, '1981-01-12', 'hit', None),
        (1, '1981-01-13', 'hit', None),
        (1, '1981-01-14', 'good', '1981-01-15'),
        (1, '1981-01-15', 'OK', '1981-01-16'),
        (2, '1981-01-06', 'good', '1981-01-17'),
        (3, '1981-01-07', 'hit', '1981-01-16'),
        (4, '1981-01-06', 'hit', None),
        (4, '1981-01-07', 'good', None),
        (4, '1981-01-08', 'good', '1981-01-10'),
    ]
    df1 = spark.createDataFrame(data, ['ID', 'Transaction_time', 'Status', 'final_time'])
    df1 = df1.withColumn('Transaction_time', F.col('Transaction_time').cast('date')) \
             .withColumn('final_time', F.col('final_time').cast('date'))
    #  +---+----------------+------+----------+
    #  | ID|Transaction_time|Status|final_time|
    #  +---+----------------+------+----------+
    #  |  1|      1981-01-12|   hit|      null|
    #  |  1|      1981-01-13|   hit|      null|
    #  |  1|      1981-01-14|  good|1981-01-15|
    #  |  1|      1981-01-15|    OK|1981-01-16|
    #  |  2|      1981-01-06|  good|1981-01-17|
    #  |  3|      1981-01-07|   hit|1981-01-16|
    #  |  4|      1981-01-06|   hit|      null|
    #  |  4|      1981-01-07|  good|      null|
    #  |  4|      1981-01-08|  good|1981-01-10|
    #  +---+----------------+------+----------+

    Intermediate dfs:

    df1:
    +---+----------------+------+----------+
    | ID|Transaction_time|Status|final_time|
    +---+----------------+------+----------+
    |  1|      1981-01-12|   hit|      null|
    |  1|      1981-01-13|   hit|      null|
    |  1|      1981-01-14|  good|1981-01-15|
    |  1|      1981-01-15|    OK|1981-01-16|
    |  2|      1981-01-06|  good|1981-01-17|
    |  3|      1981-01-07|   hit|1981-01-16|
    |  4|      1981-01-06|   hit|      null|
    |  4|      1981-01-07|  good|      null|
    |  4|      1981-01-08|  good|1981-01-10|
    +---+----------------+------+----------+

    df2:
    +---+----------------+------+----------+-------+----------+----------+----------+
    | ID|Transaction_time|Status|final_time|next_st|   next_tt|    max_tt|    max_ft|
    +---+----------------+------+----------+-------+----------+----------+----------+
    |  1|      1981-01-15|    OK|1981-01-16|   null|      null|1981-01-15|1981-01-16|
    |  1|      1981-01-14|  good|1981-01-15|     OK|1981-01-15|1981-01-15|1981-01-16|
    |  1|      1981-01-13|   hit|      null|   good|1981-01-14|1981-01-15|1981-01-16|
    |  1|      1981-01-12|   hit|      null|    hit|1981-01-13|1981-01-15|1981-01-16|
    |  3|      1981-01-07|   hit|1981-01-16|   null|      null|1981-01-07|1981-01-16|
    |  2|      1981-01-06|  good|1981-01-17|   null|      null|1981-01-06|1981-01-17|
    |  4|      1981-01-08|  good|1981-01-10|   null|      null|1981-01-08|1981-01-10|
    |  4|      1981-01-07|  good|      null|   good|1981-01-08|1981-01-08|1981-01-10|
    |  4|      1981-01-06|   hit|      null|   good|1981-01-07|1981-01-08|1981-01-10|
    +---+----------------+------+----------+-------+----------+----------+----------+

    df3:
    +---+----------------+------+----------+-------+----------+----------+----------+
    | ID|Transaction_time|Status|final_time|next_st|   next_tt|    max_tt|    max_ft|
    +---+----------------+------+----------+-------+----------+----------+----------+
    |  1|      1981-01-15|    OK|1981-01-16|   null|      null|1981-01-15|1981-01-16|
    |  1|      1981-01-14|  good|1981-01-15|     OK|1981-01-15|1981-01-15|1981-01-16|
    |  1|      1981-01-13|   hit|      null|   good|1981-01-14|1981-01-15|1981-01-16|
    |  1|      1981-01-12|   hit|      null|    hit|1981-01-13|1981-01-15|1981-01-16|
    |  3|      1981-01-07|   hit|1981-01-16|   null|      null|1981-01-07|1981-01-16|
    |  2|      1981-01-06|  good|1981-01-17|   null|      null|1981-01-06|1981-01-17|
    |  4|      1981-01-08|  good|1981-01-10|   null|      null|1981-01-08|1981-01-10|
    |  4|      1981-01-07|  good|      null|   good|1981-01-08|1981-01-08|1981-01-10|
    |  4|      1981-01-06|   hit|      null|   good|1981-01-07|1981-01-08|1981-01-10|
    +---+----------------+------+----------+-------+----------+----------+----------+

    df4:
    +---+----------------+------+----------+-------+----------+----------+----------+
    | ID|Transaction_time|Status|final_time|next_st|   next_tt|    max_tt|    max_ft|
    +---+----------------+------+----------+-------+----------+----------+----------+
    |  1|      1981-01-13|   hit|      null|   good|1981-01-14|1981-01-15|1981-01-16|
    |  4|      1981-01-06|   hit|      null|   good|1981-01-07|1981-01-08|1981-01-10|
    +---+----------------+------+----------+-------+----------+----------+----------+

    df5:
    +---+------+----------+-----------+
    | id|status|start_time|finish_time|
    +---+------+----------+-----------+
    |  4|  good|1981-01-07| 1981-01-10|
    |  1|    OK|1981-01-14| 1981-01-16|
    +---+------+----------+-----------+
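
The `next_st` and `next_tt` columns in the intermediate tables come from applying `F.lag` over a window sorted by `Transaction_time` descending: in that order, the "previous" row is the chronologically next transaction. The plain-Python sketch below illustrates the idea for ID 1 only. The helper `lag` is hypothetical and merely mimics the one-row offset of the window function; it is not part of PySpark.

```python
def lag(values, offset=1):
    """Mimic a lag window function: shift values down by `offset`, padding with None."""
    return ([None] * offset + values)[:len(values)]


# Transactions for ID 1, sorted by Transaction_time descending (as in window w1).
times = ["1981-01-15", "1981-01-14", "1981-01-13", "1981-01-12"]
statuses = ["OK", "good", "hit", "hit"]

# Lagging over descending order yields the values of the chronologically next row.
next_st = lag(statuses)
next_tt = lag(times)

for row in zip(times, statuses, next_st, next_tt):
    print(row)
```

The row for 1981-01-13 ends up with Status "hit" and `next_st` "good", which is exactly the transition the `df4` filter looks for, and its `next_tt` (1981-01-14) becomes `start_time`.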