Tags: sql, apache-spark-sql, window-functions, gaps-and-islands

How to get the last value for every partition to impute missing values in Spark SQL


I have some sample data where I want to impute the missing values. The rows where the data is missing are denoted by 'blank'. Here is the sample data -

val my_df = spark.sql(s"""
select 1 as id, 1 as time_gmt, 'a' as pagename
union
select 1 as id, 2 as time_gmt, 'b' as pagename
union
select 1 as id, 3 as time_gmt, 'blank' as pagename
union
select 1 as id, 4 as time_gmt, 'blank' as pagename
union
select 1 as id, 5 as time_gmt, 'd' as pagename
union
select 2 as id, 1 as time_gmt, 'c' as pagename
union
select 2 as id, 2 as time_gmt, 'a' as pagename
union
select 2 as id, 3 as time_gmt, 'c' as pagename
union
select 2 as id, 4 as time_gmt, 'blank' as pagename
union
select 2 as id, 5 as time_gmt, 'd' as pagename
""")
my_df.createOrReplaceTempView("my_df")

scala> my_df.orderBy("id","time_gmt").show(false)
+---+--------+--------+
|id |time_gmt|pagename|
+---+--------+--------+
|1  |1       |a       |
|1  |2       |b       |
|1  |3       |blank   |
|1  |4       |blank   |
|1  |5       |d       |
|2  |1       |c       |
|2  |2       |a       |
|2  |3       |c       |
|2  |4       |blank   |
|2  |5       |d       |
+---+--------+--------+

As you can see, there are 2 blanks for data with id 1 and 1 blank for data with id 2. I want to fill in those values using the latest non-blank value observed for each id, ordered by the time_gmt column. So my output would be -

+---+--------+--------+----------------+
|id |time_gmt|pagename|pagename_imputed|
+---+--------+--------+----------------+
|1  |1       |a       |a               |
|1  |2       |b       |b               |
|1  |3       |blank   |b               |
|1  |4       |blank   |b               |
|1  |5       |d       |d               |
|2  |1       |c       |c               |
|2  |2       |a       |a               |
|2  |3       |c       |c               |
|2  |4       |blank   |c               |
|2  |5       |d       |d               |
+---+--------+--------+----------------+

How can I do this in Spark SQL?

NOTE - blanks can appear multiple times in each partition after non-blank values.


Solution

  • One option uses window functions. The idea is to define groups of records, where "blank" records belong to the same group as the last non-blank record.

    Assuming that by blank you mean null, we can define the groups with a window count:

    select id, time_gmt, 
        max(pagename) over(partition by id, grp) as pagename
    from (
        select t.*, 
            -- count() skips nulls, so the running count does not advance
            -- on null rows: each null shares a grp with the last non-null
            count(pagename) over(partition by id order by time_gmt) as grp
        from my_df t
    ) t
    
    

    If you really mean the string 'blank' (as in your sample data), then:

    select id, time_gmt, pagename,
        max(case when pagename <> 'blank' then pagename end) 
            over(partition by id, grp) as pagename_imputed
    from (
        select t.*, 
            -- the running sum only advances on non-blank rows, so each
            -- 'blank' row shares a grp with the last non-blank before it
            sum(case when pagename = 'blank' then 0 else 1 end) 
                over(partition by id order by time_gmt) as grp
        from my_df t
    ) t