Tags: apache-spark, apache-spark-sql, correlated-subquery

Spark SQL - Keep one result after join


I have two dataframes that I am trying to join together, and I realized that with my original implementation I was getting undesirable results:

// plain_txns_df.show(false)
+------------+---------------+---------------+-----------+
|txn_date_at |account_number |merchant_name  |txn_amount |
+------------+---------------+---------------+-----------+
|2020-04-08  |1234567        |Starbucks      |2.02       |
|2020-04-14  |1234567        |Starbucks      |2.86       |
|2020-04-14  |1234567        |Subway         |12.02      |
|2020-04-14  |1234567        |Amazon         |3.21       |
+------------+---------------+---------------+-----------+
// richer_txns_df.show(false)
+----------+-------+----------------------+-------------+
|TXN_DT    |ACCT_NO|merch_name            |merchant_city|
+----------+-------+----------------------+-------------+
|2020-04-08|1234567|Starbucks             |Toronto      |
|2020-04-14|1234567|Starbucks             |Toronto      |
+----------+-------+----------------------+-------------+

From the above two dataframes, my goal is to enrich the plain transactions with the merchant city, for transactions that fall within a 7-day window (i.e. the transaction date from the richer transactions dataframe should be between the plain transaction date and the plain transaction date minus 7 days).
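
For reference, here is roughly how the two dataframes can be set up and registered as temp views so the queries below can refer to them by name (a minimal sketch reconstructed from the sample rows above; the column types are my assumption):

import spark.implicits._
import org.apache.spark.sql.functions.to_date

// Rough reconstruction of the sample data shown above; types are assumptions.
val plain_txns_df = Seq(
  ("2020-04-08", "1234567", "Starbucks", 2.02),
  ("2020-04-14", "1234567", "Starbucks", 2.86),
  ("2020-04-14", "1234567", "Subway", 12.02),
  ("2020-04-14", "1234567", "Amazon", 3.21)
).toDF("txn_date_at", "account_number", "merchant_name", "txn_amount")
  .withColumn("txn_date_at", to_date($"txn_date_at"))

val richer_txns_df = Seq(
  ("2020-04-08", "1234567", "Starbucks", "Toronto"),
  ("2020-04-14", "1234567", "Starbucks", "Toronto")
).toDF("TXN_DT", "ACCT_NO", "merch_name", "merchant_city")
  .withColumn("TXN_DT", to_date($"TXN_DT"))

plain_txns_df.createOrReplaceTempView("plain_txns_df")
richer_txns_df.createOrReplaceTempView("richer_txns_df")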

Initially I thought it was pretty straightforward and joined the data like so (a range join, I know):

spark.sql(
    """
      | SELECT
      | plain.txn_date_at,
      | plain.account_number,
      | plain.merchant_name,
      | plain.txn_amount,
      | richer.merchant_city
      | FROM plain_txns_df plain
      | LEFT JOIN richer_txns_df richer
      | ON plain.account_number = richer.ACCT_NO
      | AND plain.merchant_name = richer.merch_name
      | AND richer.TXN_DT BETWEEN date_sub(plain.txn_date_at, 7) AND plain.txn_date_at
    """.stripMargin)

However, when using the above, I get duplicate results for the April 14th transaction, because the merchant and account details also match the richer record from the 8th, which falls within the date range:

+------------+---------------+---------------+-----------+-------------+
|txn_date_at |account_number |merchant_name  |txn_amount |merchant_city|
+------------+---------------+---------------+-----------+-------------+
|2020-04-08  |1234567        |Starbucks      |2.02       |Toronto      |
|2020-04-14  |1234567        |Starbucks      |2.86       |Toronto      | // Apr-08 Richer record
|2020-04-14  |1234567        |Starbucks      |2.86       |Toronto      |
+------------+---------------+---------------+-----------+-------------+

Is there a way I can get just one record for each row in my plain DataFrame (i.e. only one record for the 14th in the above result set)? I tried running a distinct after the join, which solves this particular problem, but I realized that if there were two transactions on the same day for the same merchant and amount, I would lose one of them.
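
The distinct attempt looked roughly like this (a sketch; it removes the duplicated Apr-14 row, but it would also collapse two genuinely separate same-day transactions with the same amount into one):

spark.sql(
    """
      | SELECT DISTINCT
      | plain.txn_date_at,
      | plain.account_number,
      | plain.merchant_name,
      | plain.txn_amount,
      | richer.merchant_city
      | FROM plain_txns_df plain
      | LEFT JOIN richer_txns_df richer
      | ON plain.account_number = richer.ACCT_NO
      | AND plain.merchant_name = richer.merch_name
      | AND richer.TXN_DT BETWEEN date_sub(plain.txn_date_at, 7) AND plain.txn_date_at
    """.stripMargin)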

I was thinking of moving the richer table into a subquery and applying the date filter inside it, but I don't know how to pass the plain transaction date into that query :(. Something like the following, which doesn't work because the subquery doesn't recognize the plain transaction date:

spark.sql(
    """
      | SELECT
      | plain.txn_date_at,
      | plain.account_number,
      | plain.merchant_name,
      | plain.txn_amount,
      | richer2.merchant_city
      | FROM plain_txns_df plain
      | LEFT JOIN ( 
      |    SELECT ACCT_NO, merch_name, merchant_city from richer_txns_df
      |    WHERE TXN_DT BETWEEN date_sub(plain.txn_date_at, 7) AND plain.txn_date_at
      | ) richer2
      | ON plain.account_number = richer2.ACCT_NO
      | AND plain.merchant_name = richer2.merch_name
    """.stripMargin)

Solution

  • The first thing that I think needs to be done is to create a unique key on plain_txns_df, which makes it possible to distinguish rows from one another when aggregating or comparing them.

    import org.apache.spark.sql.functions._
    // withColumn returns a new DataFrame, so capture it and refresh the temp view
    val plainWithIdDf = plain_txns_df.withColumn("id", monotonically_increasing_id())
    plainWithIdDf.createOrReplaceTempView("plain_txns_df")
    

    With that in place you can run the first query you posted (plus the id column), which still returns duplicates:

    spark.sql("""
        SELECT
        plain.id,
        plain.txn_date_at,
        plain.account_number,
        plain.merchant_name,
        plain.txn_amount,
        richer.merchant_city,
        richer.txn_dt
        FROM plain_txns_df plain
        INNER JOIN richer_txns_df richer
        ON plain.account_number = richer.ACCT_NO
        AND plain.merchant_name = richer.merch_name
        AND richer.txn_dt BETWEEN date_sub(plain.txn_date_at, 7) AND plain.txn_date_at
      """.stripMargin).createOrReplaceTempView("foo")
    

    Next, deduplicate the above dataframe by keeping, for each id, the record with the latest richer_txns_df.txn_dt date.

    spark.sql("""
        SELECT
        f1.txn_date_at,
        f1.account_number,
        f1.merchant_name,
        f1.txn_amount,
        f1.merchant_city
        FROM foo f1
        LEFT JOIN foo f2
        ON f2.id = f1.id
        AND f2.txn_dt > f1.txn_dt
        WHERE f2.id IS NULL
      """.stripMargin).show