PySpark window function with conditions to round number of travelers


I am using PySpark and would like to write a function that performs the following operation.

Given data describing the transactions of train users:

+----+----------+--------+-----+
|date|total_trav|num_trav|order|
+----+----------+--------+-----+
|   1|         9|     2.7|    1|
|   1|         9|     1.3|    2|
|   1|         9|     1.3|    3|
|   1|         9|     1.3|    4|
|   1|         9|     1.2|    5|
|   1|         9|     1.1|    6|
|   2|         9|     2.7|    1|
|   2|         9|     1.3|    2|
|   2|         9|     1.3|    3|
|   2|         9|     1.3|    4|
|   2|         9|     1.2|    5|
|   2|         9|     1.1|    6|
+----+----------+--------+-----+

I would like to round the values in the num_trav column, group by group (grouping by date) and in the order given by the order column, to obtain the trav_res column. The logic is roughly:

  • We group the data by date.
  • Within each group (date=1 and date=2 here), every num_trav is rounded up with ceil(num_trav), whatever its value. Each group also has a maximum number of travelers (total_trav), which is 9 for both groups in this example.
  • This is where the order column comes in: rounding proceeds in the order given by that column, checking at each step how many travelers are still left for the group.
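
The steps above can be sketched in plain Python as a reference against which a Spark result can be checked. The function name `allocate_travelers` and the `(order, num_trav)` tuple layout are hypothetical, chosen only for illustration. The sketch also assumes that a row which cannot get its full ceiling receives whatever is left in the group; the example data never hits that case, so this is one reading of the rule rather than something stated explicitly.

```python
import math

def allocate_travelers(rows, total_trav):
    """Greedy allocation for ONE date group.

    rows: list of (order, num_trav) tuples (hypothetical layout).
    Returns {order: trav_res}.
    """
    remaining = total_trav
    result = {}
    # Process rows in the order given by the `order` value.
    for order, num_trav in sorted(rows):
        wanted = math.ceil(num_trav)       # always round up
        granted = min(wanted, remaining)   # assumption: a partial grant is allowed
        result[order] = granted
        remaining -= granted
    return result

group = [(1, 2.7), (2, 1.3), (3, 1.3), (4, 1.3), (5, 1.2), (6, 1.1)]
print(allocate_travelers(group, 9))
# {1: 3, 2: 2, 3: 2, 4: 2, 5: 0, 6: 0}
```

A row-by-row loop like this does not scale in Spark, but it is handy for verifying a window-based solution on small samples.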

For example, let's consider this result dataframe and see how the trav_res column is formed:

+----+----------+--------+-----+--------+
|date|total_trav|num_trav|order|trav_res|
+----+----------+--------+-----+--------+
|   1|         9|     2.7|    1|       3|
|   1|         9|     1.3|    2|       2|
|   1|         9|     1.3|    3|       2|
|   1|         9|     1.3|    4|       2|
|   1|         9|     1.2|    5|       0|
|   1|         9|     1.1|    6|       0|
|   2|         9|     2.7|    1|       3|
|   2|         9|     1.3|    2|       2|
|   2|         9|     1.3|    3|       2|
|   2|         9|     1.3|    4|       2|
|   2|         9|     1.2|    5|       0|
|   2|         9|     1.1|    6|       0|
+----+----------+--------+-----+--------+

In the example above, grouping by date gives 2 groups, each with a maximum of 9 travelers (total_trav column). For group 1, for example, you start by rounding num_trav=2.7 to 3 (trav_res column), then num_trav=1.3 to 2 three times in a row, following the order given. That uses 3 + 2 + 2 + 2 = 9 travelers, so none are left for the remaining rows, which get trav_res=0 regardless of their num_trav.

I have tried some UDFs, but they do not seem to do the job.


Solution

  • The solution is based on @AnnaK.'s answer, with a small addition so that it takes into account that the total number of travelers (total_trav) has to be used: no more, no less.

    from pyspark.sql import Window
    from pyspark.sql import functions as F

    # df_j_test_res is assumed to hold the input data shown in the question

    # round every num_trav up
    df = df_j_test_res.withColumn("num_trav_ceil", F.ceil("num_trav"))

    # running total of the ceilings within each date, in the given order
    w = Window.partitionBy("date").orderBy("order")
    df = df.withColumn("num_trav_ceil_cumsum", F.sum("num_trav_ceil").over(w))

    # how far the running total overshoots the group's limit
    overshoot = F.col("num_trav_ceil_cumsum") - F.col("total_trav")

    # keep the ceiling while the limit holds, give 1 when it is exceeded by at most 1, else 0
    df = (df
      .withColumn("trav_res",
                  F.when(overshoot <= 0, F.col("num_trav_ceil"))
                   .when((overshoot > 0) & (overshoot <= 1), 1)
                   .otherwise(0))
      .select("date", "total_trav", "num_trav", "order", "trav_res"))
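
The `when` chain above hands out exactly 1 whenever the running total overshoots total_trav by at most 1, which matches the sample data because every ceiling at the boundary is 2. As a hedged illustration of a more general variant, the plain-Python sketch below clamps each row to whatever was left before it, using the same cumulative sum. The function names `clamp_allocation` and `posted_rule` are hypothetical, and the clamp is an assumption about the intended behavior, not part of the original answer.

```python
import math
from itertools import accumulate

def clamp_allocation(num_trav_ordered, total_trav):
    """Cumulative-sum variant: grant min(ceiling, travelers left), never below 0."""
    ceilings = [math.ceil(x) for x in num_trav_ordered]
    cumsum = list(accumulate(ceilings))
    result = []
    for ceil_i, cum_i in zip(ceilings, cumsum):
        # travelers still available before this row is processed
        left_before = total_trav - (cum_i - ceil_i)
        result.append(max(0, min(ceil_i, left_before)))
    return result

def posted_rule(num_trav_ordered, total_trav):
    """Plain-Python mirror of the when/otherwise chain in the answer above."""
    ceilings = [math.ceil(x) for x in num_trav_ordered]
    result = []
    for ceil_i, cum_i in zip(ceilings, accumulate(ceilings)):
        if cum_i <= total_trav:
            result.append(ceil_i)
        elif 0 < cum_i - total_trav <= 1:
            result.append(1)
        else:
            result.append(0)
    return result

sample = [2.7, 1.3, 1.3, 1.3, 1.2, 1.1]
print(clamp_allocation(sample, 9))      # [3, 2, 2, 2, 0, 0]
print(posted_rule(sample, 9))           # [3, 2, 2, 2, 0, 0]

# Two rows with ceiling 3 and a limit of 5: the two rules differ.
print(clamp_allocation([2.7, 2.5], 5))  # [3, 2]
print(posted_rule([2.7, 2.5], 5))       # [3, 1]
```

In PySpark the clamp would presumably translate to `F.greatest(F.lit(0), F.least(F.col("num_trav_ceil"), F.col("total_trav") - (F.col("num_trav_ceil_cumsum") - F.col("num_trav_ceil"))))`, reusing the columns from the answer. That expression has not been run against the original data, so treat it as a starting point.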