Tags: sql, pyspark, apache-spark-sql

Survivorship rules in SQL


I have an input dataset like the one below, and I want to apply survivorship rules per 'supplierid':

 1. If any row for a supplier has 'Cond A', the supplier should be mapped to 'Cond A'.

 2. If one of a supplier's rows has 'A' and the other has 'Not A', the supplier should be mapped to 'A'.

 3. If both rows have the same elig, deduplicate them and keep one.

I am able to handle the last rule, but I am struggling to apply the first two.

Input dataset:

supplierid,elig,source  
1,A,source1  
1,Not A,source2  
2,A,source1  
2,Cond A,source2  
3,Not A,source1  
3,Cond A,source2  
4,Cond A,source1  
4,Cond A,source2  
5,Not A,source1  
5,Not A,source2  

Output dataset should be:

supplierid,elig  
1,A  
2,Cond A  
3,Cond A  
4,Cond A  
5,Not A
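The three rules amount to a simple precedence per supplier: 'Cond A' wins outright, 'A' beats 'Not A', and identical values collapse to one. As a sanity check of that reading, here is a minimal plain-Python sketch (not Spark; the function and variable names are illustrative, not from the question):

```python
from collections import defaultdict

# (supplierid, elig, source) rows from the question
rows = [
    (1, "A", "source1"), (1, "Not A", "source2"),
    (2, "A", "source1"), (2, "Cond A", "source2"),
    (3, "Not A", "source1"), (3, "Cond A", "source2"),
    (4, "Cond A", "source1"), (4, "Cond A", "source2"),
    (5, "Not A", "source1"), (5, "Not A", "source2"),
]

def survive(eligs):
    """Collapse one supplier's elig values using the rule precedence."""
    s = set(eligs)
    if "Cond A" in s:   # rule 1: any 'Cond A' wins
        return "Cond A"
    if "A" in s:        # rule 2: 'A' beats 'Not A'
        return "A"
    return "Not A"      # rule 3: only identical 'Not A' values remain

groups = defaultdict(list)
for supplierid, elig, _source in rows:
    groups[supplierid].append(elig)

result = {sid: survive(eligs) for sid, eligs in sorted(groups.items())}
print(result)
# {1: 'A', 2: 'Cond A', 3: 'Cond A', 4: 'Cond A', 5: 'Not A'}
```

This reproduces the expected output above, so any SQL solution just needs to evaluate that precedence over all of a supplier's rows at once.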

I tried the query below, but it does not work as expected: the CASE expression evaluates each row's elig on its own, so the first branch matches any single 'A' or 'Not A' value and the rules are never applied across a supplier's rows.

SELECT  
  supplier_id,  
  MAX(CASE  
    WHEN elig IN ('A', 'Not A') THEN 'A'  
    WHEN elig IN ('A', 'Cond A') THEN 'Cond A'  
    WHEN elig IN ('Not A', 'Cond A') THEN 'Cond A'  
    ELSE 'Not A'  
  END) OVER (PARTITION BY supplier_id) AS elig  
FROM poc_demo;  
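One portable fix is conditional aggregation: compute per-supplier flags inside GROUP BY first, then let a single CASE apply the precedence. A sketch run against SQLite purely for demonstration (the table and column names follow the question; in Spark SQL the flag would be written `MAX(CASE WHEN elig = 'Cond A' THEN 1 ELSE 0 END)` since `elig = '...'` as a bare 0/1 integer is SQLite behavior):

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE poc_demo (supplierid INT, elig TEXT, source TEXT)")
con.executemany(
    "INSERT INTO poc_demo VALUES (?, ?, ?)",
    [(1, "A", "source1"), (1, "Not A", "source2"),
     (2, "A", "source1"), (2, "Cond A", "source2"),
     (3, "Not A", "source1"), (3, "Cond A", "source2"),
     (4, "Cond A", "source1"), (4, "Cond A", "source2"),
     (5, "Not A", "source1"), (5, "Not A", "source2")],
)

# Aggregate each supplier's flags first, then apply the precedence once.
rows = con.execute("""
    SELECT supplierid,
           CASE
             WHEN MAX(elig = 'Cond A') = 1 THEN 'Cond A'   -- rule 1
             WHEN MAX(elig = 'A') = 1      THEN 'A'        -- rule 2
             ELSE 'Not A'                                  -- rule 3
           END AS elig
    FROM poc_demo
    GROUP BY supplierid
    ORDER BY supplierid
""").fetchall()
print(rows)
# [(1, 'A'), (2, 'Cond A'), (3, 'Cond A'), (4, 'Cond A'), (5, 'Not A')]
```

Because the CASE runs after aggregation, it sees the whole supplier at once, which is exactly what the per-row window/CASE attempt above could not do.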

Solution

  • I used collect_set to build an array of each supplier's elig values, then used array_contains to test the array and applied the rules in a CASE expression:

    val dF1 = Seq(
    (1,"A"     ,"source1" ), 
    (1,"Not A" ,"source2" ), 
    (2,"A"     ,"source1" ), 
    (2,"Cond A","source2" ), 
    (3,"Not A" ,"source1" ), 
    (3,"Cond A","source2" ), 
    (4,"Cond A","source1" ), 
    (4,"Cond A","source2" ), 
    (5,"Not A" ,"source1" ), 
    (5,"Not A" ,"source2" )  
    ).toDF("supplierid","elig","source")
    
    dF1.createOrReplaceTempView("poc_demo")
    
    
    spark.sql("""
    select supplierid, case when array_contains(eligs,'Cond A') then 'Cond A'                          -- rule 1
                            when array_contains(eligs,'A') and array_contains(eligs,'Not A') then 'A'  -- rule 2
                            else elig
                        end elig
    from
        (select supplierid, elig, collect_set(elig) over(partition by supplierid) eligs
        from poc_demo d
        )s
    group by 1,2  -- rule 3: dedup identical (supplierid, elig) pairs
    """).show(100, false)
    

    Result:

    +----------+------+
    |supplierid|elig  |
    +----------+------+
    |1         |A     |
    |2         |Cond A|
    |3         |Cond A|
    |4         |Cond A|
    |5         |Not A |
    +----------+------+
    

    The driver code above is in Scala, but the same SQL runs unchanged in PySpark via spark.sql.