Tags: apache-spark, pyspark, null

Back and forward fill null values in a Spark DataFrame using PySpark


I have a Spark DataFrame in which I need to compute a new column ("desired_outcome") over a window partitioned by user_id. Within each partition, null values should be forward filled from the last non-null value, and rows that come before the first non-null value in the sort order should be back filled from that first non-null value.

Here is a sample Spark dataframe with the desired output:

columns = ['user_id', 'date', 'desired_outcome']
data = [
        ('1', None,         '2022-01-05'),
        ('1', None,         '2022-01-05'),
        ('1', '2022-01-05', '2022-01-05'),
        ('1', None,         '2022-01-05'),
        ('1', None,         '2022-01-05'),
        ('2', None,         '2022-01-07'),
        ('2', None,         '2022-01-07'),
        ('2', '2022-01-07', '2022-01-07'),
        ('2', None,         '2022-01-07'),
        ('2', '2022-01-09', '2022-01-09'),
        ('2', None,         '2022-01-09'),
        ('3', '2022-01-01', '2022-01-01'),
        ('3', None,         '2022-01-01'),
        ('3', None,         '2022-01-01'),
        ('3', '2022-01-04', '2022-01-04'),
        ('3', None,         '2022-01-04'),
        ('3', None,         '2022-01-04')]

sample_df = spark.createDataFrame(data, columns)
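
One thing worth noting: a Spark DataFrame has no guaranteed row order, so any fill like this needs an explicit ordering column for the window to sort by. As a minimal sketch (the row_id name is my own choice), one way to tag rows with their ingestion order:

import pyspark.sql.functions as F

# Tag each row with an increasing id; the ids are increasing but not
# consecutive, and for a DataFrame built from a local list they should
# follow the source row order.
ordered_df = sample_df.withColumn("row_id", F.monotonically_increasing_id())
ordered_df.orderBy("row_id").show(truncate=False)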

[UPDATE]

Based on the proposed solution of @user238607 (see the first answer below), I should add this update: the desired output does not depend on any chronological (or other) ordering of the date column's values; it depends only on the original sequence of the records within each user_id group (in the general case the value column can be of any type: numeric, string, etc.). I ran the proposed code, with some edits, against another input, and the calculated desired output is not quite correct: F.least() picks the chronologically smaller of the previous and next non-null values, which only coincides with the desired fill when the dates happen to be sorted. The incorrectly calculated values are marked with three dashes in the correct column:

from pyspark.sql import Window
import pyspark.sql.functions as F

# `spark` is the already-running SparkSession of the environment


data1 = [
        ('1', None,         '2022-02-12'),
        ('1', None,         '2022-02-12'),
        ('1', '2022-02-12', '2022-02-12'),
        ('1', None,         '2022-02-12'),
        ('1', None,         '2022-02-12'),
        ('2', None,         '2022-04-09'),
        ('2', None,         '2022-04-09'),
        ('2','2022-04-09',  '2022-04-09'),
        ('2',None,          '2022-04-09'),
        ('2','2022-01-07',  '2022-01-07'),
        ('2',None,          '2022-01-07'),
        ('3','2022-11-05',  '2022-11-05'),
        ('3',None,          '2022-11-05'),
        ('3',None,          '2022-11-05'),
        ('3','2022-01-04',  '2022-01-04'),
        ('3',None,          '2022-01-04'),
        ('3',None,          '2022-01-04'),
        ('3','2022-04-15',  '2022-04-15'),
        ('3', None,         '2022-04-15')]

columns = ['row_id', 'user_id', 'date', 'desired_outcome_given']

## add an explicit row index so the original order of the rows is preserved
data2 = [ (index, item[0], item[1], item[2]) for index, item in enumerate(data1) ]

df1 = spark.createDataFrame(data=data2, schema=columns)

print("Given dataframe")
df1.show(n=10, truncate=False)

# first non-null date at or after the current row (source for the back fill)
window_min_spec = Window.partitionBy("user_id").orderBy(F.col("row_id").asc()).rowsBetween(0, Window.unboundedFollowing)
# last non-null date at or before the current row (source for the forward fill)
window_max_spec = Window.partitionBy("user_id").orderBy(F.col("row_id").asc()).rowsBetween(Window.unboundedPreceding, 0)

df1 = df1.withColumn("first_date", F.first("date", ignorenulls=True).over(window_min_spec))
df1 = df1.withColumn("last_date", F.last("date", ignorenulls=True).over(window_max_spec))
print("Calculated first and last dates")
df1.show(truncate=False)
print("Final dataframe")
output = (df1
          .withColumn("desired_outcome_calculated", F.least("first_date", "last_date"))
          .withColumn("correct", F.when(F.col("desired_outcome_given") == F.col("desired_outcome_calculated"), F.lit('true')).otherwise('---'))
          .select('row_id', 'user_id', 'date', 'desired_outcome_given', 'desired_outcome_calculated', 'correct'))
output.show(truncate=False)

Output:


Given dataframe
+------+-------+----------+---------------------+
|row_id|user_id|date      |desired_outcome_given|
+------+-------+----------+---------------------+
|0     |1      |NULL      |2022-02-12           |
|1     |1      |NULL      |2022-02-12           |
|2     |1      |2022-02-12|2022-02-12           |
|3     |1      |NULL      |2022-02-12           |
|4     |1      |NULL      |2022-02-12           |
|5     |2      |NULL      |2022-04-09           |
|6     |2      |NULL      |2022-04-09           |
|7     |2      |2022-04-09|2022-04-09           |
|8     |2      |NULL      |2022-04-09           |
|9     |2      |2022-01-07|2022-01-07           |
+------+-------+----------+---------------------+
only showing top 10 rows

Calculated first and last dates
+------+-------+----------+---------------------+----------+----------+
|row_id|user_id|date      |desired_outcome_given|first_date|last_date |
+------+-------+----------+---------------------+----------+----------+
|0     |1      |NULL      |2022-02-12           |2022-02-12|NULL      |
|1     |1      |NULL      |2022-02-12           |2022-02-12|NULL      |
|2     |1      |2022-02-12|2022-02-12           |2022-02-12|2022-02-12|
|3     |1      |NULL      |2022-02-12           |NULL      |2022-02-12|
|4     |1      |NULL      |2022-02-12           |NULL      |2022-02-12|
|5     |2      |NULL      |2022-04-09           |2022-04-09|NULL      |
|6     |2      |NULL      |2022-04-09           |2022-04-09|NULL      |
|7     |2      |2022-04-09|2022-04-09           |2022-04-09|2022-04-09|
|8     |2      |NULL      |2022-04-09           |2022-01-07|2022-04-09|
|9     |2      |2022-01-07|2022-01-07           |2022-01-07|2022-01-07|
|10    |2      |NULL      |2022-01-07           |NULL      |2022-01-07|
|11    |3      |2022-11-05|2022-11-05           |2022-11-05|2022-11-05|
|12    |3      |NULL      |2022-11-05           |2022-01-04|2022-11-05|
|13    |3      |NULL      |2022-11-05           |2022-01-04|2022-11-05|
|14    |3      |2022-01-04|2022-01-04           |2022-01-04|2022-01-04|
|15    |3      |NULL      |2022-01-04           |2022-04-15|2022-01-04|
|16    |3      |NULL      |2022-01-04           |2022-04-15|2022-01-04|
|17    |3      |2022-04-15|2022-04-15           |2022-04-15|2022-04-15|
|18    |3      |NULL      |2022-04-15           |NULL      |2022-04-15|
+------+-------+----------+---------------------+----------+----------+

Final dataframe
+------+-------+----------+---------------------+--------------------------+-------+
|row_id|user_id|date      |desired_outcome_given|desired_outcome_calculated|correct|
+------+-------+----------+---------------------+--------------------------+-------+
|0     |1      |NULL      |2022-02-12           |2022-02-12                |true   |
|1     |1      |NULL      |2022-02-12           |2022-02-12                |true   |
|2     |1      |2022-02-12|2022-02-12           |2022-02-12                |true   |
|3     |1      |NULL      |2022-02-12           |2022-02-12                |true   |
|4     |1      |NULL      |2022-02-12           |2022-02-12                |true   |
|5     |2      |NULL      |2022-04-09           |2022-04-09                |true   |
|6     |2      |NULL      |2022-04-09           |2022-04-09                |true   |
|7     |2      |2022-04-09|2022-04-09           |2022-04-09                |true   |
|8     |2      |NULL      |2022-04-09           |2022-01-07                |---    |
|9     |2      |2022-01-07|2022-01-07           |2022-01-07                |true   |
|10    |2      |NULL      |2022-01-07           |2022-01-07                |true   |
|11    |3      |2022-11-05|2022-11-05           |2022-11-05                |true   |
|12    |3      |NULL      |2022-11-05           |2022-01-04                |---    |
|13    |3      |NULL      |2022-11-05           |2022-01-04                |---    |
|14    |3      |2022-01-04|2022-01-04           |2022-01-04                |true   |
|15    |3      |NULL      |2022-01-04           |2022-01-04                |true   |
|16    |3      |NULL      |2022-01-04           |2022-01-04                |true   |
|17    |3      |2022-04-15|2022-04-15           |2022-04-15                |true   |
|18    |3      |NULL      |2022-04-15           |2022-04-15                |true   |
+------+-------+----------+---------------------+--------------------------+-------+


Solution

  • You can do something like this.

    from pyspark import SparkContext
    from pyspark.sql import SQLContext, Window
    import pyspark.sql.functions as F


    sc = SparkContext('local')
    sqlContext = SQLContext(sc)
    
    
    data1 = [
            ('1', None,         '2022-01-05'),
            ('1', None,         '2022-01-05'),
            ('1', '2022-01-05', '2022-01-05'),
            ('1', None,         '2022-01-05'),
            ('1', None,         '2022-01-05'),
            ('2', None,         '2022-01-07'),
            ('2', None,         '2022-01-07'),
            ('2','2022-01-07',  '2022-01-07'),
            ('2',None,          '2022-01-07'),
            ('2','2022-01-09',  '2022-01-09'),
            ('2',None,          '2022-01-09'),
            ('3','2022-01-01',  '2022-01-01'),
            ('3',None,          '2022-01-01'),
            ('3',None,          '2022-01-01'),
            ('3','2022-01-04',  '2022-01-04'),
            ('3',None,          '2022-01-04'),
            ('3',None,          '2022-01-04')]
    
    columns = ['row_id', 'user_id', 'date', 'desired_outcome_given']
    
    ## add an explicit row index so the original order of the rows is preserved
    data2 = [ (index, item[0], item[1], item[2]) for index, item in enumerate(data1) ]
    
    df1 = sqlContext.createDataFrame(data=data2, schema=columns)
    
    print("Given dataframe")
    df1.show(n=10, truncate=False)
    
    # first non-null date at or after the current row (the next non-null value)
    window_min_spec = Window.partitionBy("user_id").orderBy(F.col("row_id").asc()).rowsBetween(0, Window.unboundedFollowing)
    # last non-null date at or before the current row (the previous non-null value)
    window_max_spec = Window.partitionBy("user_id").orderBy(F.col("row_id").asc()).rowsBetween(Window.unboundedPreceding, 0)

    df1 = df1.withColumn("min_date", F.first("date", ignorenulls=True).over(window_min_spec))
    df1 = df1.withColumn("max_date", F.last("date", ignorenulls=True).over(window_max_spec))
    print("Calculated min and max")
    df1.show(truncate=False)
    print("Final dataframe")
    output = df1.withColumn("desired_outcome_calculated", F.least("min_date", "max_date")) \
                .select("desired_outcome_given", "desired_outcome_calculated")
    output.show(truncate=False)
    

    Output:

    Given dataframe
    +------+-------+----------+---------------------+
    |row_id|user_id|date      |desired_outcome_given|
    +------+-------+----------+---------------------+
    |0     |1      |null      |2022-01-05           |
    |1     |1      |null      |2022-01-05           |
    |2     |1      |2022-01-05|2022-01-05           |
    |3     |1      |null      |2022-01-05           |
    |4     |1      |null      |2022-01-05           |
    |5     |2      |null      |2022-01-07           |
    |6     |2      |null      |2022-01-07           |
    |7     |2      |2022-01-07|2022-01-07           |
    |8     |2      |null      |2022-01-07           |
    |9     |2      |2022-01-09|2022-01-09           |
    +------+-------+----------+---------------------+
    only showing top 10 rows
    
    Calculated min and max
    +------+-------+----------+---------------------+----------+----------+
    |row_id|user_id|date      |desired_outcome_given|min_date  |max_date  |
    +------+-------+----------+---------------------+----------+----------+
    |0     |1      |null      |2022-01-05           |2022-01-05|null      |
    |1     |1      |null      |2022-01-05           |2022-01-05|null      |
    |2     |1      |2022-01-05|2022-01-05           |2022-01-05|2022-01-05|
    |3     |1      |null      |2022-01-05           |null      |2022-01-05|
    |4     |1      |null      |2022-01-05           |null      |2022-01-05|
    |5     |2      |null      |2022-01-07           |2022-01-07|null      |
    |6     |2      |null      |2022-01-07           |2022-01-07|null      |
    |7     |2      |2022-01-07|2022-01-07           |2022-01-07|2022-01-07|
    |8     |2      |null      |2022-01-07           |2022-01-09|2022-01-07|
    |9     |2      |2022-01-09|2022-01-09           |2022-01-09|2022-01-09|
    |10    |2      |null      |2022-01-09           |null      |2022-01-09|
    |11    |3      |2022-01-01|2022-01-01           |2022-01-01|2022-01-01|
    |12    |3      |null      |2022-01-01           |2022-01-04|2022-01-01|
    |13    |3      |null      |2022-01-01           |2022-01-04|2022-01-01|
    |14    |3      |2022-01-04|2022-01-04           |2022-01-04|2022-01-04|
    |15    |3      |null      |2022-01-04           |null      |2022-01-04|
    |16    |3      |null      |2022-01-04           |null      |2022-01-04|
    +------+-------+----------+---------------------+----------+----------+
    
    Final dataframe
    +---------------------+--------------------------+
    |desired_outcome_given|desired_outcome_calculated|
    +---------------------+--------------------------+
    |2022-01-05           |2022-01-05                |
    |2022-01-05           |2022-01-05                |
    |2022-01-05           |2022-01-05                |
    |2022-01-05           |2022-01-05                |
    |2022-01-05           |2022-01-05                |
    |2022-01-07           |2022-01-07                |
    |2022-01-07           |2022-01-07                |
    |2022-01-07           |2022-01-07                |
    |2022-01-07           |2022-01-07                |
    |2022-01-09           |2022-01-09                |
    |2022-01-09           |2022-01-09                |
    |2022-01-01           |2022-01-01                |
    |2022-01-01           |2022-01-01                |
    |2022-01-01           |2022-01-01                |
    |2022-01-04           |2022-01-04                |
    |2022-01-04           |2022-01-04                |
    |2022-01-04           |2022-01-04                |
    +---------------------+--------------------------+
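
    For intuition (my note, not part of the original answer): F.least() skips NULLs, and in this input the dates increase within each user, so the previous non-null value (max_date) is never larger than the next one (min_date). A minimal sketch of the two cases, with values taken from rows 8 and 10 above:

    pairs = sqlContext.createDataFrame(
        [('2022-01-09', '2022-01-07'),   # row 8: both present -> least() returns the previous (smaller) value
         (None,         '2022-01-09')],  # row 10: next value missing -> least() skips the NULL
        ['min_date', 'max_date'])
    pairs.withColumn("filled", F.least("min_date", "max_date")).show()

    This is exactly what breaks on the updated data, where the previous non-null value can be chronologically larger than the next one.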
    

    EDIT: As per the new update, F.least() no longer works once the dates are not sorted within a user, so here's how I would do it instead: take the last non-null value at or before the current row, and fall back to the next non-null value only when there is none.

    columns = ['row_id', 'user_id', 'date', 'desired_outcome_given']
    
    ## add an explicit row index so the original order of the rows is preserved
    data2 = [ (index, item[0], item[1], item[2]) for index, item in enumerate(data1) ]
    
    df1 = sqlContext.createDataFrame(data=data2, schema=columns)
    
    print("Given dataframe")
    df1.show(n=10, truncate=False)
    
    # forward fill: last non-null date at or before the current row
    window_fillup_notnull_spec = Window.partitionBy("user_id").orderBy(F.col("row_id").asc()).rowsBetween(Window.unboundedPreceding, 0)
    # back fill: first non-null date at or after the current row
    window_fillbelow_notnull_spec = Window.partitionBy("user_id").orderBy(F.col("row_id").asc()).rowsBetween(0, Window.unboundedFollowing)

    df1 = df1.withColumn("up_values", F.last("date", ignorenulls=True).over(window_fillup_notnull_spec))
    df1 = df1.withColumn("below_values", F.first("date", ignorenulls=True).over(window_fillbelow_notnull_spec))
    print("Calculated min and max")
    df1.show(truncate=False)
    print("Final dataframe")
    output = df1.withColumn("desired_outcome_calculated",
                            F.when(F.col("up_values").isNotNull(), F.col("up_values"))
                             .otherwise(F.col("below_values"))) \
                .select("desired_outcome_given", "desired_outcome_calculated")
    output.show(truncate=False)
    

    Output is as follows:

    Final dataframe
    +---------------------+--------------------------+
    |desired_outcome_given|desired_outcome_calculated|
    +---------------------+--------------------------+
    |2022-02-12           |2022-02-12                |
    |2022-02-12           |2022-02-12                |
    |2022-02-12           |2022-02-12                |
    |2022-02-12           |2022-02-12                |
    |2022-02-12           |2022-02-12                |
    |2022-04-09           |2022-04-09                |
    |2022-04-09           |2022-04-09                |
    |2022-04-09           |2022-04-09                |
    |2022-04-09           |2022-04-09                |
    |2022-01-07           |2022-01-07                |
    |2022-01-07           |2022-01-07                |
    |2022-11-05           |2022-11-05                |
    |2022-11-05           |2022-11-05                |
    |2022-11-05           |2022-11-05                |
    |2022-01-04           |2022-01-04                |
    |2022-01-04           |2022-01-04                |
    |2022-01-04           |2022-01-04                |
    |2022-04-15           |2022-04-15                |
    |2022-04-15           |2022-04-15                |
    +---------------------+--------------------------+
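
    As a side note (my addition, not in the original answer): the when/otherwise above can be written more compactly with F.coalesce, which reads as "forward fill, else back fill":

    # Equivalent formulation: take the forward-filled value when present,
    # otherwise fall back to the back-filled one.
    output = df1.withColumn("desired_outcome_calculated",
                            F.coalesce(F.col("up_values"), F.col("below_values"))) \
                .select("desired_outcome_given", "desired_outcome_calculated")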