
Difference between a Row and its lead by 3 Rows in a PySpark DataFrame


I have a CSV file that I imported as a DataFrame with the following code:

from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()
df = spark.read.csv("name of file.csv", inferSchema = True, header = True)
df.show()

output

    +-----+------+-----+
    |col1 | col2 | col3|
    +-----+------+-----+    
    |  A  |  2   |  4  |
    +-----+------+-----+    
    |  A  |  4   |  5  | 
    +-----+------+-----+    
    |  A  |  7   |  7  | 
    +-----+------+-----+    
    |  A  |  3   |  8  | 
    +-----+------+-----+    
    |  A  |  7   |  3  | 
    +-----+------+-----+    
    |  B  |  8   |  9  |
    +-----+------+-----+    
    |  B  |  10  |  10 | 
    +-----+------+-----+    
    |  B  |  8   |  9  |
    +-----+------+-----+    
    |  B  |  20  |  15 |
    +-----+------+-----+

I want to add a new column, col4, containing col2[n+3]/col2[n] - 1, computed separately for each group in col1.

The output should be

   +-----+------+-----+-----+
   |col1 | col2 | col3| col4|
   +-----+------+-----+-----+
   | A   |    2 |   4 | 0.5 |  #(3/2-1)
   +-----+------+-----+-----+
   | A   |    4 |   5 | 0.75|  #(7/4-1)
   +-----+------+-----+-----+
   | A   |    7 |   7 |  NA |
   +-----+------+-----+-----+
   | A   |    3 |   8 |  NA |
   +-----+------+-----+-----+
   | A   |    7 |   3 |  NA |
   +-----+------+-----+-----+
   | B   |    8 |   9 | 1.5 |  #(20/8-1)
   +-----+------+-----+-----+
   | B   |   10 |  10 |  NA |
   +-----+------+-----+-----+
   | B   |    8 |   9 |  NA |
   +-----+------+-----+-----+
   | B   |   20 |  15 |  NA |
   +-----+------+-----+-----+

I know how to do this in pandas, but I am not sure how to perform a computation on a grouped column in PySpark.

I am currently using PySpark 2.4.


Solution

  • This can be done with lead() over a Window partitioned by col1. For reference, my Spark version is 2.2.

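    from pyspark.sql.window import Window
    from pyspark.sql.functions import lead, col

    # Ordering by the partition key does not fix the row order inside a group,
    # so this relies on input order; order by a sequence column if one exists.
    my_window = Window.partitionBy('col1').orderBy('col1')
    # lead(col2, 3) is col2 from three rows ahead in the same group, else null.
    df = df.withColumn('col2_lead_3', lead(col('col2'), 3).over(my_window))\
           .withColumn('col4', (col('col2_lead_3') / col('col2')) - 1).drop('col2_lead_3')
    df.show()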
    +----+----+----+----+
    |col1|col2|col3|col4|
    +----+----+----+----+
    |   B|   8|   9| 1.5|
    |   B|  10|  10|null|
    |   B|   8|   9|null|
    |   B|  20|  15|null|
    |   A|   2|   4| 0.5|
    |   A|   4|   5|0.75|
    |   A|   7|   7|null|
    |   A|   3|   8|null|
    |   A|   7|   3|null|
    +----+----+----+----+
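
To sanity-check the numbers above, here is a plain-Python sketch (not PySpark) of what a lead by 3 within each col1 group computes. The rows are the sample data from the question; the helper name lead_ratio and the use of None in place of Spark's null are my own assumptions for illustration only.

```python
from collections import defaultdict

# Sample rows copied from the question: (col1, col2, col3).
rows = [
    ("A", 2, 4), ("A", 4, 5), ("A", 7, 7), ("A", 3, 8), ("A", 7, 3),
    ("B", 8, 9), ("B", 10, 10), ("B", 8, 9), ("B", 20, 15),
]


def lead_ratio(rows, offset=3):
    """Append col2[n+offset] / col2[n] - 1 to each row, per col1 group.

    Hypothetical helper for illustration; None stands in for Spark's null.
    """
    # Group rows by col1, keeping the input order inside each group.
    groups = defaultdict(list)
    for row in rows:
        groups[row[0]].append(row)

    result = []
    for members in groups.values():
        for i, (col1, col2, col3) in enumerate(members):
            j = i + offset  # index of the row `offset` positions ahead
            col4 = members[j][1] / col2 - 1 if j < len(members) else None
            result.append((col1, col2, col3, col4))
    return result


for row in lead_ratio(rows):
    print(row)
```

Only the first two rows of group A and the first row of group B have a row three positions ahead, which is why every other row ends up with no value in col4.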