Tags: python, python-3.x, dataframe, pyspark, window-functions

PySpark: Using window functions to roll up a DataFrame


I have a dataframe my_df that contains 4 columns:

+----------------+---------------+--------+---------+
|         user_id|         domain|isp_flag|frequency|
+----------------+---------------+--------+---------+
|            josh|     wanadoo.fr|       1|       15|
|            josh|      random.it|       0|       12|
|        samantha|     wanadoo.fr|       1|       16|
|             bob|    eidsiva.net|       1|        5|
|             bob|      media.net|       0|        1|
|           dylan|    vodafone.it|       1|      448|
|           dylan|   somesite.net|       0|       20|
|           dylan|   yolosite.net|       0|       49|
|           dylan|      random.it|       0|        3|
|             don|    vodafone.it|       1|       39|
|             don|   popsugar.com|       0|       10|
|             don|      fabio.com|       1|       49|
+----------------+---------------+--------+---------+

Here is what I'm planning to do:

Find all user_ids for which the highest frequency among domains with isp_flag=0 is less than 25% of the highest frequency among domains with isp_flag=1, and for each of those users keep only the maximum-frequency row for each isp_flag.

So, for the example above, my output_df would look like this:

+----------------+---------------+--------+---------+
|         user_id|         domain|isp_flag|frequency|
+----------------+---------------+--------+---------+
|             bob|    eidsiva.net|       1|        5|
|             bob|      media.net|       0|        1|
|           dylan|    vodafone.it|       1|      448|
|           dylan|   yolosite.net|       0|       49|
|             don|      fabio.com|       1|       49|
|             don|   popsugar.com|       0|       10|
+----------------+---------------+--------+---------+
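The rule can be checked by hand against the sample data. Below is a plain-Python sketch (no Spark involved; the names rows, max_freq, qualifying and output are illustrative only) that computes each user's top frequency per isp_flag, applies the 25% threshold, and keeps only the per-flag maximum rows:

```python
# Sample rows from my_df: (user_id, domain, isp_flag, frequency)
rows = [
    ("josh", "wanadoo.fr", 1, 15),
    ("josh", "random.it", 0, 12),
    ("samantha", "wanadoo.fr", 1, 16),
    ("bob", "eidsiva.net", 1, 5),
    ("bob", "media.net", 0, 1),
    ("dylan", "vodafone.it", 1, 448),
    ("dylan", "somesite.net", 0, 20),
    ("dylan", "yolosite.net", 0, 49),
    ("dylan", "random.it", 0, 3),
    ("don", "vodafone.it", 1, 39),
    ("don", "popsugar.com", 0, 10),
    ("don", "fabio.com", 1, 49),
]

# Highest frequency per (user_id, isp_flag)
max_freq = {}
for user, _domain, flag, freq in rows:
    key = (user, flag)
    max_freq[key] = max(max_freq.get(key, freq), freq)

# Users whose top isp_flag=0 frequency is below 25% of their top isp_flag=1 frequency;
# a user with no isp_flag=0 rows (samantha) has nothing to compare and is skipped
qualifying = {
    user
    for (user, flag) in max_freq
    if flag == 1
    and (user, 0) in max_freq
    and max_freq[(user, 0)] < 0.25 * max_freq[(user, 1)]
}

# Keep only each qualifying user's maximum-frequency row(s) per flag
output = [
    (user, domain, flag, freq)
    for user, domain, flag, freq in rows
    if user in qualifying and freq == max_freq[(user, flag)]
]

for row in output:
    print(row)
```

This yields the same six rows as output_df above. josh drops out because 12 is not below 0.25 * 15 = 3.75.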

I believe I need window functions for this, so as a first step I tried to find, for each user_id, the maximum-frequency domain for isp_flag=0 and for isp_flag=1:

>>> from pyspark.sql import Window
>>> from pyspark.sql.functions import col, rank
>>> win_1 = Window().partitionBy("user_id", "domain", "isp_flag").orderBy(col("frequency").desc())
>>> final_df = my_df.select("*", rank().over(win_1).alias("rank")).filter(col("rank") == 1)
>>> final_df.show(5)   # this just gives me the original dataframe back

What am I doing wrong here? How do I get to the final output_df I printed above?


Solution

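  • IIUC, you can try the following: calculate the maximum frequencies max_1 and max_0 for each user over the rows with isp_flag == 1 and isp_flag == 0 respectively. Then keep only the rows where max_0 < 0.25*max_1 and the row's frequency is the maximum for its own flag (frequency == max_1 for isp_flag 1, frequency == max_0 for isp_flag 0). Users with no isp_flag == 0 row (samantha) get a NULL max_0 and are filtered out.

    from pyspark.sql import Window, functions as F
    
    # set up the Window to calculate max_0 and max_1 for each user
    # having isp_flag = 0 and 1 respectively; the frame spans the whole partition
    w1 = Window.partitionBy('user_id').rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
    
    (my_df
        .withColumn('max_1', F.max(F.expr("IF(isp_flag == 1, frequency, NULL)")).over(w1))
        .withColumn('max_0', F.max(F.expr("IF(isp_flag == 0, frequency, NULL)")).over(w1))
        # keep users whose top isp_flag=0 frequency is under 25% of their top isp_flag=1
        # frequency, and only the row(s) holding the maximum for their own flag
        .where('max_0 < 0.25 * max_1 AND '
               '((isp_flag = 1 AND frequency = max_1) OR (isp_flag = 0 AND frequency = max_0))')
        .show())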
    +-------+------------+--------+---------+-----+-----+                           
    |user_id|      domain|isp_flag|frequency|max_1|max_0|
    +-------+------------+--------+---------+-----+-----+
    |    don|popsugar.com|       0|       10|   49|   10|
    |    don|   fabio.com|       1|       49|   49|   10|
    |  dylan| vodafone.it|       1|      448|  448|   49|
    |  dylan|yolosite.net|       0|       49|  448|   49|
    |    bob| eidsiva.net|       1|        5|    5|    1|
    |    bob|   media.net|       0|        1|    5|    1|
    +-------+------------+--------+---------+-----+-----+
    

    Some explanations, per request:

    • The WindowSpec w1 partitions by user_id, and its frame covers the whole partition (unboundedPreceding to unboundedFollowing), so F.max() compares all rows belonging to the same user.

    • IF(isp_flag==1, frequency, NULL) returns frequency for rows with isp_flag==1 and NULL otherwise. F.max() skips NULLs, so it only sees the isp_flag==1 frequencies. Because IF(...) is a Spark SQL expression rather than a Python function, it has to be wrapped in F.expr().

    • F.max(...).over(w1) takes the maximum of that expression over the window w1, i.e. over all rows of the same user, and writes the result onto every one of those rows. That is why max_1 and max_0 can be compared row by row in the filter.