I have two DataFrames that I am trying to join together, and I realized that my original implementation was producing undesirable results:
// plain_txns_df.show(false)
+------------+---------------+---------------+-----------+
|txn_date_at |account_number |merchant_name |txn_amount |
+------------+---------------+---------------+-----------+
|2020-04-08 |1234567 |Starbucks |2.02 |
|2020-04-14 |1234567 |Starbucks |2.86 |
|2020-04-14 |1234567 |Subway |12.02 |
|2020-04-14 |1234567 |Amazon |3.21 |
+------------+---------------+---------------+-----------+
// richer_txns_df.show(false)
+----------+-------+----------------------+-------------+
|TXN_DT |ACCT_NO|merch_name |merchant_city|
+----------+-------+----------------------+-------------+
|2020-04-08|1234567|Starbucks             |Toronto      |
|2020-04-14|1234567|Starbucks             |Toronto      |
+----------+-------+----------------------+-------------+
From the above two DataFrames, my goal is to enrich the plain transactions with the merchant city, for transactions that are within a 7-day window (i.e. the transaction date from the richer DataFrame should be between the plain date and the plain date minus 7 days).
Initially I thought it was pretty straightforward and joined the data like so (a range join, I know):
spark.sql(
  """
    | SELECT
    |   plain.txn_date_at,
    |   plain.account_number,
    |   plain.merchant_name,
    |   plain.txn_amount,
    |   richer.merchant_city
    | FROM plain_txns_df plain
    | LEFT JOIN richer_txns_df richer
    |   ON plain.account_number = richer.ACCT_NO
    |   AND plain.merchant_name = richer.merch_name
    |   AND richer.TXN_DT BETWEEN date_sub(plain.txn_date_at, 7) AND plain.txn_date_at
  """.stripMargin)
However, when using the above, I get duplicate results for the April 14th transaction, because the merchant and account details also match the richer record from the 8th, which falls within the date range:
+------------+---------------+---------------+-----------+-------------+
|txn_date_at |account_number |merchant_name |txn_amount |merchant_city|
+------------+---------------+---------------+-----------+-------------+
|2020-04-08 |1234567 |Starbucks |2.02 |Toronto |
|2020-04-14 |1234567 |Starbucks |2.86 |Toronto | // Apr-08 Richer record
|2020-04-14 |1234567 |Starbucks |2.86 |Toronto |
+------------+---------------+---------------+-----------+-------------+
Is there a way I can get just one record for each row in my plain DataFrame (i.e. a single record for the 14th in the above result set)? I tried running a distinct after the join, which solves this case, but I realize that if there were two transactions on the same day for the same merchant, I would lose one of them.
I was thinking of moving the richer table into a subquery and applying the date filter inside it, but I don't know how to pass the transaction date filter value into that query :(. Something like the following, but it doesn't recognize the plain transaction date:
spark.sql(
  """
    | SELECT
    |   plain.txn_date_at,
    |   plain.account_number,
    |   plain.merchant_name,
    |   plain.txn_amount,
    |   richer2.merchant_city
    | FROM plain_txns_df plain
    | LEFT JOIN (
    |   SELECT ACCT_NO, merch_name, merchant_city FROM richer_txns_df
    |   WHERE TXN_DT BETWEEN date_sub(plain.txn_date_at, 7) AND plain.txn_date_at
    | ) richer2
    |   ON plain.account_number = richer2.ACCT_NO
    |   AND plain.merchant_name = richer2.merch_name
  """.stripMargin)
The first thing that needs to be done is to create a unique key on plain_txns_df, which makes it possible to distinguish its rows from one another when comparing them later.
import org.apache.spark.sql.functions._

plain_txns_df
  .withColumn("id", monotonically_increasing_id())
  .createOrReplaceTempView("plain_txns_df")
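Note that monotonically_increasing_id() produces ids that are guaranteed unique and increasing, but not consecutive; re-registering the temp view is what makes the new id column visible to the SQL queries below.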
With that in place you can run the first query you posted (plus the id column), which still returns duplicates:
spark.sql("""
SELECT
plain.id,
plain.txn_date_at,
plain.account_number,
plain.merchant_name,
plain.txn_amount,
richer.merchant_city,
richer.txn_dt
FROM plain_txns_df plain
INNER JOIN richer_txns_df richer
ON plain.account_number = richer.acc_no
AND plain.merchant_name = richer.merch_name
AND richer.txn_dt BETWEEN date_sub(plain.txn_date_at, 7) AND plain.txn_date_at
""".stripMargin).createOrReplaceTempView("foo")
Next, deduplicate the above view by keeping, for each id, only the record with the latest richer_txns_df.TXN_DT date:
spark.sql("""
SELECT
f1.txn_date_at,
f1.account_number,
f1.merchant_name,
f1.txn_amount,
f1.merchant_city
FROM foo f1
LEFT JOIN foo f2
ON f2.id = f1.id
AND f2.txn_dt > f1.txn_dt
WHERE f2.id IS NULL
""".stripMargin).show