Tags: sql, jinja2, snowflake-cloud-data-platform, dbt

SQL use LAG to keep looking for value until criteria is met


Essentially, I have what I call floating rows that come into my table. These are rows with type = 'mod' that are associated with a customer_id but not with another transaction_id in the table. I need them associated with another transaction_id rather than hanging out by themselves, so I want to create a mapping table that pairs each floating transaction_id with the previous_transaction_id it belongs to.

I'm using LAG to get the previous_transaction_id, and that works for some scenarios but not all. I'd like to tell LAG: "Hey LAG, if the value you found is one of these floating transaction_id values, keep stepping back until you find one that isn't." But I don't know how to do that.

I'm using Snowflake and have jinja and dbt at my disposal if there is a better way to do this.

Here is some mock data with the various scenarios I have and my current query:

with data as (
select 
    transaction_id,
    customer_id,
    transaction_date,
    amount,
    type
from (values 
    (1, 'a','03/24/2022'::date, 10, 'cat'),
    (1, 'a','03/24/2022'::date, 15, 'dog'),
    (1, 'a','03/24/2022'::date, 20, 'mouse'),
    (1, 'a','03/24/2022'::date, 30, 'rabbit'),
    (1, 'a','03/24/2022'::date, 5, 'squirrel'),
    (2, 'a','03/24/2022'::date, 4, 'mod'), -- floater
    (3, 'b','05/20/2022'::date, 100, 'cat'),
    (3, 'b','05/20/2022'::date, 150, 'dog'),
    (3, 'b','05/20/2022'::date, 200, 'mouse'),
    (3, 'b','05/20/2022'::date, 300, 'rabbit'),
    (3, 'b','05/20/2022'::date, 50, 'squirrel'),
    (4, 'b','07/20/2022'::date, 40, 'mod'), -- floater
    (5, 'c','02/02/2020'::date, 100, 'cat'),
    (5, 'c','02/02/2020'::date, 150, 'dog'),
    (5, 'c','02/02/2020'::date, 200, 'mouse'),
    (5, 'c','02/02/2020'::date, 300, 'rabbit'),
    (6, 'c','08/01/2020'::date, 50, 'mod'), -- floater
    (7, 'c','12/25/2020'::date, 40, 'mod'), -- floater
    (8, 'd','01/15/2021'::date, 10, 'cat'),
    (8, 'd','01/15/2021'::date, 15, 'dog'),
    (8, 'd','01/15/2021'::date, 20, 'mouse'),
    (8, 'd','01/15/2021'::date, 30, 'rabbit'),
    (8, 'd','01/15/2021'::date, 5, 'squirrel'),
    (8, 'd','01/15/2021'::date, 4, 'mod'),
    (9, 'e','02/10/2020'::date, 100, 'cat'),
    (9, 'e','02/10/2020'::date, 150, 'dog'),
    (9, 'e','02/10/2020'::date, 200, 'mouse'),
    (9, 'e','02/10/2020'::date, 300, 'rabbit'),
    (10, 'e','08/17/2020'::date, 50, 'mod'), -- floater
    (11, 'e','12/15/2020'::date, 40, 'mod'), -- floater
    (12, 'e','02/14/2021'::date, 40, 'mod'), -- floater
    (13, 'c','04/09/2022'::date, 0, 'mouse'),
    (13, 'c','04/09/2022'::date, 0, 'rabbit'),
    (13, 'c','04/09/2022'::date, 50, 'mod') -- floater because other values for transaction_id sum to 0
    ) as tbl (transaction_id, customer_id, transaction_date, amount, type)
),
previous_transaction_id as (
select 
    transaction_id,
    customer_id,
    lag(transaction_id, 1, null) over (partition by customer_id order by transaction_date) as previous_transaction_id
from data
   qualify transaction_id != previous_transaction_id
),

floating_mods as (
    select 
        transaction_id,
        sum(iff(type = 'mod', amount, 0)) as mod_amount,
        sum(amount) - mod_amount as non_mod_amount
    from data
    group by 1
    having non_mod_amount = 0
)

select 
    gp.transaction_id,
    gp.previous_transaction_id
from previous_transaction_id gp
    inner join floating_mods fm on gp.transaction_id = fm.transaction_id
order by gp.transaction_id

And here is the output of the query:

TRANSACTION_ID  PREVIOUS_TRANSACTION_ID
2               1
4               3
6               5
7               6
10              9
11              10
12              11
13              7

And here is my desired output:

TRANSACTION_ID  PREVIOUS_TRANSACTION_ID
2               1
4               3
6               5
7               5
10              9
11              9
12              9
13              5

Solution

  • I'm not sure this captures the rules precisely, but it should be close enough to tweak if necessary. LAG can be told to skip nulls with IGNORE NULLS, and you decide what counts as a null for LAG's purposes by wrapping its argument in IFF or a CASE WHEN expression: return null for the rows LAG should step over, and the TRANSACTION_ID otherwise.

    with data as (
    select 
        transaction_id,
        customer_id,
        transaction_date,
        amount,
        type
    from (values 
        (1, 'a','03/24/2022'::date, 10, 'cat'),
        (1, 'a','03/24/2022'::date, 15, 'dog'),
        (1, 'a','03/24/2022'::date, 20, 'mouse'),
        (1, 'a','03/24/2022'::date, 30, 'rabbit'),
        (1, 'a','03/24/2022'::date, 5, 'squirrel'),
        (2, 'a','03/24/2022'::date, 4, 'mod'), -- floater
        (3, 'b','05/20/2022'::date, 100, 'cat'),
        (3, 'b','05/20/2022'::date, 150, 'dog'),
        (3, 'b','05/20/2022'::date, 200, 'mouse'),
        (3, 'b','05/20/2022'::date, 300, 'rabbit'),
        (3, 'b','05/20/2022'::date, 50, 'squirrel'),
        (4, 'b','07/20/2022'::date, 40, 'mod'), -- floater
        (5, 'c','02/02/2020'::date, 100, 'cat'),
        (5, 'c','02/02/2020'::date, 150, 'dog'),
        (5, 'c','02/02/2020'::date, 200, 'mouse'),
        (5, 'c','02/02/2020'::date, 300, 'rabbit'),
        (6, 'c','08/01/2020'::date, 50, 'mod'), -- floater
        (7, 'c','12/25/2020'::date, 40, 'mod'), -- floater
        (8, 'd','01/15/2021'::date, 10, 'cat'),
        (8, 'd','01/15/2021'::date, 15, 'dog'),
        (8, 'd','01/15/2021'::date, 20, 'mouse'),
        (8, 'd','01/15/2021'::date, 30, 'rabbit'),
        (8, 'd','01/15/2021'::date, 5, 'squirrel'),
        (8, 'd','01/15/2021'::date, 4, 'mod'),
        (9, 'e','02/10/2020'::date, 100, 'cat'),
        (9, 'e','02/10/2020'::date, 150, 'dog'),
        (9, 'e','02/10/2020'::date, 200, 'mouse'),
        (9, 'e','02/10/2020'::date, 300, 'rabbit'),
        (10, 'e','08/17/2020'::date, 50, 'mod'), -- floater
        (11, 'e','12/15/2020'::date, 40, 'mod'), -- floater
        (12, 'e','02/14/2021'::date, 40, 'mod'), -- floater
        (13, 'c','04/09/2022'::date, 0, 'mouse'),
        (13, 'c','04/09/2022'::date, 0, 'rabbit'),
        (13, 'c','04/09/2022'::date, 50, 'mod') -- floater because other values for transaction_id sum to 0
        ) as tbl (transaction_id, customer_id, transaction_date, amount, type)
    ), LAGGED as
    (
    select   TRANSACTION_ID
            ,TYPE
            ,lag(iff(type = 'mod' or AMOUNT = 0, null, TRANSACTION_ID)) ignore nulls over (order by TRANSACTION_ID) PREVIOUS_TRANSACTION_ID
    from data
    )
    select TRANSACTION_ID, PREVIOUS_TRANSACTION_ID 
    from LAGGED
    where type = 'mod' and TRANSACTION_ID <> PREVIOUS_TRANSACTION_ID
    ;
    
    TRANSACTION_ID  PREVIOUS_TRANSACTION_ID
    2               1
    4               3
    6               5
    7               5
    10              9
    11              9
    12              9
    13              9

    Also, in the desired output the last row shows 5. Should that be 9?
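
If the 5 in the desired output is intentional (transaction 13 belongs to customer c, whose last non-floating transaction is 5), one possible tweak is to run LAG per customer rather than across the whole table. Below is a minimal sketch of that idea, not a tested production query. It assumes a floater is any transaction whose non-mod amounts sum to 0 (the same rule as the floating_mods CTE in the question), and it collapses the data to one row per transaction before calling LAG so that the row order within a transaction doesn't matter. The names transactions and is_floater are mine, and the sample is a trimmed copy of the mock data.

```sql
with data as (
    select * from (values
        (1,  'a', '2022-03-24'::date, 10,  'cat'),
        (2,  'a', '2022-03-24'::date, 4,   'mod'),   -- floater, same date as 1
        (5,  'c', '2020-02-02'::date, 100, 'cat'),
        (6,  'c', '2020-08-01'::date, 50,  'mod'),   -- floater
        (7,  'c', '2020-12-25'::date, 40,  'mod'),   -- floater
        (8,  'd', '2021-01-15'::date, 10,  'cat'),
        (8,  'd', '2021-01-15'::date, 4,   'mod'),   -- not a floater: has non-mod rows
        (13, 'c', '2022-04-09'::date, 0,   'mouse'),
        (13, 'c', '2022-04-09'::date, 50,  'mod')    -- floater: non-mod amounts sum to 0
    ) as tbl (transaction_id, customer_id, transaction_date, amount, type)
),
-- One row per transaction, flagged as a floater when its non-mod amounts sum to 0.
transactions as (
    select
        transaction_id,
        customer_id,
        min(transaction_date) as transaction_date,
        sum(iff(type = 'mod', 0, amount)) = 0 as is_floater
    from data
    group by transaction_id, customer_id
),
-- Floaters are turned into nulls so that IGNORE NULLS steps over them.
-- transaction_id is a tie-breaker for transactions that share a date.
lagged as (
    select
        transaction_id,
        is_floater,
        lag(iff(is_floater, null, transaction_id)) ignore nulls
            over (partition by customer_id order by transaction_date, transaction_id)
            as previous_transaction_id
    from transactions
)
select transaction_id, previous_transaction_id
from lagged
where is_floater
order by transaction_id;
```

On this trimmed sample I'd expect 2 → 1, 6 → 5, 7 → 5 and 13 → 5. A floater with no earlier non-floating transaction for the same customer would come back with a null previous_transaction_id, so you may want to filter or handle those separately.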