I am migrating code from SAS to PySpark and I am struggling with the following SAS retain statement:
data ds;
set ds;
by group date;
retain Target;
if first.group then Target = Orig;
if first.group and ( Orig in (1,2,3,4,5) ) then Target = 6;
if not first.group and Target = 6 and (Orig in (1,2,3,4,5) ) then Target = 6 ;
if not first.group and ~(Target = 6 and (Orig in (1,2,3,4,5) ) ) then Target = Orig ;
run;
How can this be approached?
In words: if the first Orig value in a group is in 1–5, round Target up to 6 (if it is 0 or 6, Target is simply Orig). On subsequent rows, if Target is already 6 and Orig is in 1–5, keep Target at 6; otherwise (Target is not 6, or Orig is not in 1–5) set Target to the Orig value.
I am providing an example for a single group (apologies for the lengthy one):
df = spark.createDataFrame([
(999, 5,6) ,
(999, 6,6) ,
(999, 4,6) ,
(999, 6,6) ,
(999, 3,6) ,
(999, 5,6) ,
(999, 4,6) ,
(999, 6,6) ,
(999, 6,6) ,
(999, 6,6) ,
(999, 6,6) ,
(999, 5,6) ,
(999, 3,6) ,
(999, 2,6) ,
(999, 2,6) ,
(999, 2,6) ,
(999, 2,6) ,
(999, 2,6) ,
(999, 2,6) ,
(999, 2,6) ,
(999, 2,6) ,
(999, 1,6) ,
(999, 0,0) ,
(999, 0,0) ,
(999, 0,0) ,
(999, 0,0) ,
(999, 1,1) ,
(999, 1,1) ,
(999, 2,2) ,
(999, 2,2) ,
(999, 3,3) ,
(999, 2,2) ,
(999, 3,3) ,
(999, 4,4) ,
(999, 5,5) ,
(999, 6,6) ,
(999, 6,6) ,
(999, 6,6) ,
(999, 0,0) ,
(999, 1,1) ,
(999, 0,0) ,
(999, 1,1) ,
(999, 2,2) ,
(999, 3,3) ,
(999, 4,4) ,
(999, 5,5) ,
(999, 6,6) ,
(999, 6,6) ,
(999, 6,6) ,
(999, 6,6) ,
(999, 4,6) ,
(999, 3,6) ,
(999, 2,6) ,
(999, 3,6) ,
(999, 4,6) ,
(999, 5,6) ,
(999, 6,6) ,
(999, 6,6) ],
['Group', 'Orig', 'Target']
)
There's no straightforward way of replicating SAS' retain, but one can build on its logic. retain, in a way, looks at the previously calculated value to calculate the current value: you're essentially lagging the column you're creating, while you're creating it. PySpark can do this using an array of structs and higher-order functions.
Here's how to achieve it. Note that I've added a date field, dt, to sort the data within each group, similar to your by group date; statement.
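As a minimal sketch of how data_sdf could be prepared: the name spark for the SparkSession, the shortened row list, and the column names below are assumptions based on the question's sample, not part of the original answer.

```python
from datetime import datetime, timedelta

# sample rows in (group, orig, exp_tgt) form, as in the question
# (shortened here; the full row list from the question would be used)
rows = [(999, 5, 6), (999, 6, 6), (999, 0, 0), (999, 1, 1)]

# add a sequential dt column starting 2020-01-01, one day per row,
# so array_sort can order the structs the way `by group date;` would
start = datetime(2020, 1, 1)
rows_with_dt = [(g, start + timedelta(days=i), o, t)
                for i, (g, o, t) in enumerate(rows)]

# data_sdf = spark.createDataFrame(rows_with_dt,
#                                  ['group', 'dt', 'orig', 'exp_tgt'])
```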
from pyspark.sql import functions as func

# convert data to an array of structs per group
arr_struct_sdf = data_sdf. \
withColumn('allattr', func.struct('dt', 'orig', 'exp_tgt')). \
groupBy('group'). \
agg(func.array_sort(func.collect_list('allattr')).alias('allattr')). \
withColumn('frst_elm', func.col('allattr')[0])
# +-----+--------------------+--------------------+
# |group| allattr| frst_elm|
# +-----+--------------------+--------------------+
# | 999|[{2020-01-01 00:0...|{2020-01-01 00:00...|
# +-----+--------------------+--------------------+
# use `aggregate` higher order function to generate `target` field
arr_struct_sdf. \
withColumn('new_attr',
func.aggregate(func.expr('slice(allattr, 2, size(allattr))'),
func.array(func.col('frst_elm').withField('tgt',
func.when(func.col('frst_elm.orig').isin(1,2,3,4,5), func.lit(6)).
otherwise(func.col('frst_elm.orig'))
)
),
lambda x, y: func.array_union(x,
func.array(y.withField('tgt',
func.when((func.element_at(x, -1).tgt == 6) & (y.orig.isin(1,2,3,4,5)), func.lit(6)).
otherwise(y.orig)
)
)
)
)
). \
selectExpr('group', 'inline(new_attr)'). \
show(100, False)
# +-----+-------------------+----+-------+---+
# |group|dt |orig|exp_tgt|tgt|
# +-----+-------------------+----+-------+---+
# |999 |2020-01-01 00:00:00|5 |6 |6 |
# |999 |2020-01-02 00:00:00|6 |6 |6 |
# |999 |2020-01-03 00:00:00|4 |6 |6 |
# |999 |2020-01-04 00:00:00|6 |6 |6 |
# |999 |2020-01-05 00:00:00|3 |6 |6 |
# |999 |2020-01-06 00:00:00|5 |6 |6 |
# |999 |2020-01-07 00:00:00|4 |6 |6 |
# |999 |2020-01-08 00:00:00|6 |6 |6 |
# |999 |2020-01-09 00:00:00|6 |6 |6 |
# |999 |2020-01-10 00:00:00|6 |6 |6 |
# |999 |2020-01-11 00:00:00|6 |6 |6 |
# |999 |2020-01-12 00:00:00|5 |6 |6 |
# |999 |2020-01-13 00:00:00|3 |6 |6 |
# |999 |2020-01-14 00:00:00|2 |6 |6 |
# |999 |2020-01-15 00:00:00|2 |6 |6 |
# |999 |2020-01-16 00:00:00|2 |6 |6 |
# |999 |2020-01-17 00:00:00|2 |6 |6 |
# |999 |2020-01-18 00:00:00|2 |6 |6 |
# |999 |2020-01-19 00:00:00|2 |6 |6 |
# |999 |2020-01-20 00:00:00|2 |6 |6 |
# |999 |2020-01-21 00:00:00|2 |6 |6 |
# |999 |2020-01-22 00:00:00|1 |6 |6 |
# |999 |2020-01-23 00:00:00|0 |0 |0 |
# |999 |2020-01-24 00:00:00|0 |0 |0 |
# |999 |2020-01-25 00:00:00|0 |0 |0 |
# |999 |2020-01-26 00:00:00|0 |0 |0 |
# |999 |2020-01-27 00:00:00|1 |1 |1 |
# |999 |2020-01-28 00:00:00|1 |1 |1 |
# |999 |2020-01-29 00:00:00|2 |2 |2 |
# |999 |2020-01-30 00:00:00|2 |2 |2 |
# |999 |2020-01-31 00:00:00|3 |3 |3 |
# |999 |2020-02-01 00:00:00|2 |2 |2 |
# |999 |2020-02-02 00:00:00|3 |3 |3 |
# |999 |2020-02-03 00:00:00|4 |4 |4 |
# |999 |2020-02-04 00:00:00|5 |5 |5 |
# |999 |2020-02-05 00:00:00|6 |6 |6 |
# |999 |2020-02-06 00:00:00|6 |6 |6 |
# |999 |2020-02-07 00:00:00|6 |6 |6 |
# |999 |2020-02-08 00:00:00|0 |0 |0 |
# |999 |2020-02-09 00:00:00|1 |1 |1 |
# |999 |2020-02-10 00:00:00|0 |0 |0 |
# |999 |2020-02-11 00:00:00|1 |1 |1 |
# |999 |2020-02-12 00:00:00|2 |2 |2 |
# |999 |2020-02-13 00:00:00|3 |3 |3 |
# |999 |2020-02-14 00:00:00|4 |4 |4 |
# |999 |2020-02-15 00:00:00|5 |5 |5 |
# |999 |2020-02-16 00:00:00|6 |6 |6 |
# |999 |2020-02-17 00:00:00|6 |6 |6 |
# |999 |2020-02-18 00:00:00|6 |6 |6 |
# |999 |2020-02-19 00:00:00|6 |6 |6 |
# |999 |2020-02-20 00:00:00|4 |6 |6 |
# |999 |2020-02-21 00:00:00|3 |6 |6 |
# |999 |2020-02-22 00:00:00|2 |6 |6 |
# |999 |2020-02-23 00:00:00|3 |6 |6 |
# |999 |2020-02-24 00:00:00|4 |6 |6 |
# |999 |2020-02-25 00:00:00|5 |6 |6 |
# |999 |2020-02-26 00:00:00|6 |6 |6 |
# |999 |2020-02-27 00:00:00|6 |6 |6 |
# +-----+-------------------+----+-------+---+
explanation
The aggregate higher-order function takes an array, an initial value, and a merge function (like Python's reduce).
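The same retain-style fold can be sketched in plain Python with functools.reduce; the step function below is a hypothetical stand-in for the merge lambda, assuming Orig values fall in 0–6:

```python
from functools import reduce

orig = [5, 6, 4, 0, 1, 2, 6, 3]   # one group's Orig values, in date order

# first.group rule: round 1-5 up to 6, otherwise keep Orig
first = orig[0]
init = [6 if first in (1, 2, 3, 4, 5) else first]

# not first.group rule: keep 6 while the previous target is 6
# and Orig stays in 1-5, otherwise fall back to Orig
def step(acc, o):
    prev = acc[-1]
    return acc + [6 if prev == 6 and o in (1, 2, 3, 4, 5) else o]

target = reduce(step, orig[1:], init)
print(target)   # [6, 6, 6, 0, 1, 2, 6, 6]
```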
The initial value is an array holding just the first struct, with its tgt field set per the first.group calculation. The merge lambda then handles the not first.group conditions: y is the element currently being calculated, and it looks at the previously calculated value, which, in turn, is the last value in x (element_at(x, -1)).
P.S. exp_tgt (expected target) is the target field already present in the sample data shared in the question; tgt is the final target field that PySpark generated.
Those who can't use the aggregate function from pyspark.sql.functions (it was added in Spark 3.1) can use the aggregate SQL function, available since Spark 2.4, within expr() like below.
data_sdf. \
withColumn('allattr', func.struct('dt', 'orig', 'exp_tgt')). \
groupBy('group'). \
agg(func.array_sort(func.collect_list('allattr')).alias('allattr')). \
withColumn('frst_elm', func.col('allattr')[0]). \
withColumn('new_attr',
func.expr('''
aggregate(slice(allattr, 2, size(allattr)),
array(struct(frst_elm.dt as dt, frst_elm.orig as orig, frst_elm.exp_tgt as exp_tgt, if(frst_elm.orig in (1,2,3,4,5), 6, frst_elm.orig) as tgt)),
(x, y) -> array_union(x,
array(struct(y.dt as dt, y.orig as orig, y.exp_tgt as exp_tgt, if(element_at(x, -1).tgt=6 and y.orig in (1,2,3,4,5), 6, y.orig) as tgt))
)
)
''')
). \
selectExpr('group', 'inline(new_attr)'). \
show(100, truncate=False)