Tags: apache-spark, duplicates, apache-spark-sql, inner-join

Dedupe rows in Spark DataFrame by most recent timestamp


I have a DataFrame with the following schema:

root
|- documentId
|- timestamp
|- anotherField

For example,

"d1", "2018-09-20 10:00:00", "blah1"
"d2", "2018-09-20 09:00:00", "blah2"
"d1", "2018-09-20 10:01:00", "blahnew"

Note that for the sake of understanding (and my convenience) I am showing the timestamp as a string. It is in fact a long representing milliseconds since epoch.
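(For anyone who wants to reproduce that detail: a minimal sketch of how one could derive a readable column from the epoch-millis long, purely for inspection; the derived column name here is made up.)

import org.apache.spark.sql.functions.col

// assumes `timestamp` is epoch milliseconds stored as a long;
// dividing by 1000 gives seconds, which Spark can cast to a timestamp
val dfReadable = df.withColumn(
  "timestampReadable",
  (col("timestamp") / 1000).cast("timestamp")
)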

As seen here, there are duplicate rows (rows 1 and 3) with the same documentId but different timestamps (and possibly different other fields). I want to dedupe and retain only the most recent row (based on timestamp) for each documentId.

A simple df.groupBy("documentId").agg(max("timestamp"), ...) does not seem to work here, because I don't know how to retain the other fields of the row that satisfies max("timestamp").
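To make that concrete, this is roughly what the plain aggregation would leave me with: only the grouping key and the aggregated timestamp, with anotherField gone (row order in the output is illustrative).

import org.apache.spark.sql.functions.max

df.groupBy("documentId")
  .agg(max("timestamp").as("timestamp"))
  .show()
// +----------+-------------------+
// |documentId|          timestamp|
// +----------+-------------------+
// |        d1|2018-09-20 10:01:00|
// |        d2|2018-09-20 09:00:00|
// +----------+-------------------+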

So, I came up with a complicated way of doing this.

// first find the max timestamp for each documentId;
// alias it back to "timestamp" so the join below can use that column name
val mostRecent = df
  .select("documentId", "timestamp")
  .groupBy("documentId")
  .agg(max("timestamp").as("timestamp"))

// then join back to the original df on documentId and timestamp
// to recover the full rows
val dedupedDf = df.join(mostRecent, Seq("documentId", "timestamp"), "inner")

The resulting dedupedDf should contain only the rows corresponding to the most recent entry for each documentId.

Although this works, I don't feel it is the right (or an efficient) approach, since it relies on a join that seems unnecessary.

How can I do this better? I am looking for pure DataFrame-based solutions rather than RDD-based approaches (since the Databricks folks have repeatedly told us in a workshop to work with DataFrames, not RDDs).


Solution

  • The window-function code below should accomplish your objective:

    // assumes a SparkSession named `spark`; in spark-shell these implicits
    // are already in scope
    import spark.implicits._
    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.row_number

    val df = Seq(
      ("d1", "2018-09-20 10:00:00", "blah1"),
      ("d2", "2018-09-20 09:00:00", "blah2"),
      ("d1", "2018-09-20 10:01:00", "blahnew")
    ).toDF("documentId", "timestamp", "anotherField")

    // rank the rows within each documentId by timestamp, newest first,
    // then keep only the top-ranked row per document
    val w = Window.partitionBy($"documentId").orderBy($"timestamp".desc)
    val resultDf = df
      .withColumn("rownum", row_number().over(w))
      .where($"rownum" === 1)
      .drop("rownum")

    resultDf.show()
    

    input:

    +----------+-------------------+------------+
    |documentId|          timestamp|anotherField|
    +----------+-------------------+------------+
    |        d1|2018-09-20 10:00:00|       blah1|
    |        d2|2018-09-20 09:00:00|       blah2|
    |        d1|2018-09-20 10:01:00|     blahnew|
    +----------+-------------------+------------+
    

    output:

    +----------+-------------------+------------+
    |documentId|          timestamp|anotherField|
    +----------+-------------------+------------+
    |        d2|2018-09-20 09:00:00|       blah2|
    |        d1|2018-09-20 10:01:00|     blahnew|
    +----------+-------------------+------------+
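
  • As a side note, another join-free, pure-DataFrame option that I believe works is to aggregate with max over a struct whose first field is the timestamp: struct comparison goes field by field, so the row with the latest timestamp wins and the remaining fields ride along. A sketch using the column names from the example above:

    import org.apache.spark.sql.functions.{col, max, struct}

    // max over a struct compares fields left to right, so putting timestamp
    // first makes the comparison effectively "latest timestamp wins"
    val latestPerDoc = df
      .groupBy("documentId")
      .agg(max(struct(col("timestamp"), col("anotherField"))).as("latest"))
      .select(col("documentId"), col("latest.timestamp"), col("latest.anotherField"))

    latestPerDoc.show()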