
Replicating a SAS Retain Statement in PySpark


I am migrating code from SAS to PySpark and I am struggling with the following SAS retain statement:

data ds;
    set ds;
    by group date;

    retain Target;

    if first.group                                                 then Target = Orig;
    if first.group     and  (Orig in (1,2,3,4,5))                  then Target = 6;

    if not first.group and  (Target = 6 and (Orig in (1,2,3,4,5))) then Target = 6;
    if not first.group and ~(Target = 6 and (Orig in (1,2,3,4,5))) then Target = Orig;
run;

How can this be approached?

The logic is: on the first row of a group, Target starts as Orig, but if that first Orig is in 1-5 it is rounded up to 6. On subsequent rows the retained Target carries over: if Target is already 6 and Orig is in 1-5, Target stays at 6; otherwise (Target is not 6, or Orig is not in 1-5) Target is set to the row's Orig value.

I am providing an example for a single group (apologies for the length):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame([
(999, 5,6) ,
(999, 6,6) ,
(999, 4,6) ,
(999, 6,6) ,
(999, 3,6) ,
(999, 5,6) ,
(999, 4,6) ,
(999, 6,6) ,
(999, 6,6) ,
(999, 6,6) ,
(999, 6,6) ,
(999, 5,6) ,
(999, 3,6) ,
(999, 2,6) ,
(999, 2,6) ,
(999, 2,6) ,
(999, 2,6) ,
(999, 2,6) ,
(999, 2,6) ,
(999, 2,6) ,
(999, 2,6) ,
(999, 1,6) ,
(999, 0,0) ,
(999, 0,0) ,
(999, 0,0) ,
(999, 0,0) ,
(999, 1,1) ,
(999, 1,1) ,
(999, 2,2) ,
(999, 2,2) ,
(999, 3,3) ,
(999, 2,2) ,
(999, 3,3) ,
(999, 4,4) ,
(999, 5,5) ,
(999, 6,6) ,
(999, 6,6) ,
(999, 6,6) ,
(999, 0,0) ,
(999, 1,1) ,
(999, 0,0) ,
(999, 1,1) ,
(999, 2,2) ,
(999, 3,3) ,
(999, 4,4) ,
(999, 5,5) ,
(999, 6,6) ,
(999, 6,6) ,
(999, 6,6) ,
(999, 6,6) ,
(999, 4,6) ,
(999, 3,6) ,
(999, 2,6) ,
(999, 3,6) ,
(999, 4,6) ,
(999, 5,6) ,
(999, 6,6) ,
(999, 6,6) ],
['Group', 'Orig', 'Target']
)

Solution

  • There's no straightforward way of replicating SAS's retain, but you can build on its logic. retain, in essence, looks at the previously calculated value to compute the current one, so you are effectively lagging the column you are creating while you are creating it.

    PySpark can do this with an array of structs and higher-order functions.

    Here's how to achieve it (note that I've added a date field, dt, to sort the rows within each group, similar to your by group date; statement).
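
    A minimal sketch of how data_sdf could be derived from the question's df is shown first. The column renames (group, orig, exp_tgt) and the synthetic dt column are my own assumptions; dt only needs to be some increasing value per group so the rows can be sorted back into their original order.

    from pyspark.sql import functions as func
    from pyspark.sql.window import Window

    # assumed helper: rename the sample columns and add a synthetic date per group
    # that preserves the original row order
    w = Window.partitionBy('group').orderBy('_id')

    data_sdf = df.toDF('group', 'orig', 'exp_tgt'). \
        withColumn('_id', func.monotonically_increasing_id()). \
        withColumn('rn', func.row_number().over(w)). \
        withColumn('dt', func.expr("cast(date_add('2020-01-01', rn - 1) as timestamp)")). \
        drop('_id', 'rn')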

    # convert data to array of structs per group
    arr_struct_sdf = data_sdf. \
        withColumn('allattr', func.struct('dt', 'orig', 'exp_tgt')). \
        groupBy('group'). \
        agg(func.array_sort(func.collect_list('allattr')).alias('allattr')). \
        withColumn('frst_elm', func.col('allattr')[0])
    
    # +-----+--------------------+--------------------+
    # |group|             allattr|            frst_elm|
    # +-----+--------------------+--------------------+
    # |  999|[{2020-01-01 00:0...|{2020-01-01 00:00...|
    # +-----+--------------------+--------------------+
    
    # use `aggregate` higher order function to generate `target` field
    arr_struct_sdf. \
        withColumn('new_attr',
                   func.aggregate(func.expr('slice(allattr, 2, size(allattr))'),
                                  func.array(func.col('frst_elm').withField('tgt', 
                                                                            func.when(func.col('frst_elm.orig').isin(1,2,3,4,5), func.lit(6)).
                                                                            otherwise(func.col('frst_elm.orig'))
                                                                            )
                                             ),
                                  lambda x, y: func.array_union(x,
                                                                func.array(y.withField('tgt',
                                                                                       func.when((func.element_at(x, -1).tgt == 6) & (y.orig.isin(1,2,3,4,5)), func.lit(6)).
                                                                                       otherwise(y.orig)
                                                                                       )
                                                                           )
                                                                )
                                  )
                   ). \
        selectExpr('group', 'inline(new_attr)'). \
        show(100, False)
    
    # +-----+-------------------+----+-------+---+
    # |group|dt                 |orig|exp_tgt|tgt|
    # +-----+-------------------+----+-------+---+
    # |999  |2020-01-01 00:00:00|5   |6      |6  |
    # |999  |2020-01-02 00:00:00|6   |6      |6  |
    # |999  |2020-01-03 00:00:00|4   |6      |6  |
    # |999  |2020-01-04 00:00:00|6   |6      |6  |
    # |999  |2020-01-05 00:00:00|3   |6      |6  |
    # |999  |2020-01-06 00:00:00|5   |6      |6  |
    # |999  |2020-01-07 00:00:00|4   |6      |6  |
    # |999  |2020-01-08 00:00:00|6   |6      |6  |
    # |999  |2020-01-09 00:00:00|6   |6      |6  |
    # |999  |2020-01-10 00:00:00|6   |6      |6  |
    # |999  |2020-01-11 00:00:00|6   |6      |6  |
    # |999  |2020-01-12 00:00:00|5   |6      |6  |
    # |999  |2020-01-13 00:00:00|3   |6      |6  |
    # |999  |2020-01-14 00:00:00|2   |6      |6  |
    # |999  |2020-01-15 00:00:00|2   |6      |6  |
    # |999  |2020-01-16 00:00:00|2   |6      |6  |
    # |999  |2020-01-17 00:00:00|2   |6      |6  |
    # |999  |2020-01-18 00:00:00|2   |6      |6  |
    # |999  |2020-01-19 00:00:00|2   |6      |6  |
    # |999  |2020-01-20 00:00:00|2   |6      |6  |
    # |999  |2020-01-21 00:00:00|2   |6      |6  |
    # |999  |2020-01-22 00:00:00|1   |6      |6  |
    # |999  |2020-01-23 00:00:00|0   |0      |0  |
    # |999  |2020-01-24 00:00:00|0   |0      |0  |
    # |999  |2020-01-25 00:00:00|0   |0      |0  |
    # |999  |2020-01-26 00:00:00|0   |0      |0  |
    # |999  |2020-01-27 00:00:00|1   |1      |1  |
    # |999  |2020-01-28 00:00:00|1   |1      |1  |
    # |999  |2020-01-29 00:00:00|2   |2      |2  |
    # |999  |2020-01-30 00:00:00|2   |2      |2  |
    # |999  |2020-01-31 00:00:00|3   |3      |3  |
    # |999  |2020-02-01 00:00:00|2   |2      |2  |
    # |999  |2020-02-02 00:00:00|3   |3      |3  |
    # |999  |2020-02-03 00:00:00|4   |4      |4  |
    # |999  |2020-02-04 00:00:00|5   |5      |5  |
    # |999  |2020-02-05 00:00:00|6   |6      |6  |
    # |999  |2020-02-06 00:00:00|6   |6      |6  |
    # |999  |2020-02-07 00:00:00|6   |6      |6  |
    # |999  |2020-02-08 00:00:00|0   |0      |0  |
    # |999  |2020-02-09 00:00:00|1   |1      |1  |
    # |999  |2020-02-10 00:00:00|0   |0      |0  |
    # |999  |2020-02-11 00:00:00|1   |1      |1  |
    # |999  |2020-02-12 00:00:00|2   |2      |2  |
    # |999  |2020-02-13 00:00:00|3   |3      |3  |
    # |999  |2020-02-14 00:00:00|4   |4      |4  |
    # |999  |2020-02-15 00:00:00|5   |5      |5  |
    # |999  |2020-02-16 00:00:00|6   |6      |6  |
    # |999  |2020-02-17 00:00:00|6   |6      |6  |
    # |999  |2020-02-18 00:00:00|6   |6      |6  |
    # |999  |2020-02-19 00:00:00|6   |6      |6  |
    # |999  |2020-02-20 00:00:00|4   |6      |6  |
    # |999  |2020-02-21 00:00:00|3   |6      |6  |
    # |999  |2020-02-22 00:00:00|2   |6      |6  |
    # |999  |2020-02-23 00:00:00|3   |6      |6  |
    # |999  |2020-02-24 00:00:00|4   |6      |6  |
    # |999  |2020-02-25 00:00:00|5   |6      |6  |
    # |999  |2020-02-26 00:00:00|6   |6      |6  |
    # |999  |2020-02-27 00:00:00|6   |6      |6  |
    # +-----+-------------------+----+-------+---+
    

    Explanation

    The aggregate higher-order function takes an array, an initial value, and a merge function (like Python's reduce); a tiny standalone illustration of its mechanics follows the list below.

    • In this case, the first parameter is the source array of structs (one struct per row), excluding the first row of the group.
    • The second parameter is the initial value: the first row of the group along with its calculated target (this covers the first.group calculations).
    • The third parameter is the merge function, which starts from the initial value and folds in the remaining elements one at a time.
      • This is where the not first.group conditions are calculated.
      • y is the element currently being calculated; it looks at the previously calculated value, which is the last element of the accumulator x (element_at(x, -1)).
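
    To make the mechanics concrete, here is a tiny standalone illustration of aggregate reducing an array to a single value (not part of the solution; spark is assumed to be the active SparkSession and func the pyspark.sql.functions import used above):

    # aggregate walks the array [1, 2, 3], starting the accumulator at 0 and
    # adding each element to it, much like Python's reduce
    spark.range(1). \
        select(func.aggregate(func.array(func.lit(1), func.lit(2), func.lit(3)),
                              func.lit(0),
                              lambda acc, x: acc + x).alias('total')). \
        show()

    # +-----+
    # |total|
    # +-----+
    # |    6|
    # +-----+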

    P.S. exp_tgt (expected target) is the target field already present in the sample data shared in the question; tgt is the final target field that PySpark generated.


    For those who are unable to use the aggregate Python function because of older Spark versions (the Python API was added in Spark 3.1), the aggregate Spark SQL function can be used within expr() like below.

    data_sdf. \
        withColumn('allattr', func.struct('dt', 'orig', 'exp_tgt')). \
        groupBy('group'). \
        agg(func.array_sort(func.collect_list('allattr')).alias('allattr')). \
        withColumn('frst_elm', func.col('allattr')[0]). \
        withColumn('new_attr', 
                   func.expr('''
                        aggregate(slice(allattr, 2, size(allattr)), 
                                  array(struct(frst_elm.dt as dt, frst_elm.orig as orig, frst_elm.exp_tgt as exp_tgt, if(frst_elm.orig in (1,2,3,4,5), 6, frst_elm.orig) as tgt)),
                                  (x, y) -> array_union(x,
                                                        array(struct(y.dt as dt, y.orig as orig, y.exp_tgt as exp_tgt, if(element_at(x, -1).tgt=6 and y.orig in (1,2,3,4,5), 6, y.orig) as tgt))
                                                        )
                                  )
                   ''')
                   ). \
        selectExpr('group', 'inline(new_attr)'). \
        show(100, truncate=False)
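
    As a quick sanity check, you can confirm that the generated tgt matches the exp_tgt already present in the sample data. This assumes (hypothetically) that the output of the selectExpr('group', 'inline(new_attr)') step above has been assigned to a variable result_sdf instead of being shown:

    # count the rows where the computed target differs from the expected one; 0 means they all match
    result_sdf.filter(func.col('tgt') != func.col('exp_tgt')).count()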