I am using an answer found at "iterate over select columns and check if a specfic value is in these select columns and use that column name that has that value to create a new table". We can use PySpark native functions to create an array of the column names that have the value 1, and that array can then be used to get the min and max of the years. However, I want a separate row for each run of consecutive 1s, i.e. a new row should start whenever a 1 comes after a 0.
Here's an example input table:
# +---+-----+---+-----+-----+-----+-----+-----+-----+
# | a| b| id|m2000|m2001|m2002|m2003|m2004|m2005|
# +---+-----+---+-----+-----+-----+-----+-----+-----+
# | a|world| 1| 0| 1| 1| 0| 0| 1|
# | b|world| 2| 0| 1| 1| 1| 1| 1|
# | c|world| 3| 1| 1| 0| 0| 1| 1|
# +---+-----+---+-----+-----+-----+-----+-----+-----+
I want the final table to be like:
# +---+-----+---+---------+-------+
# |  a|    b| id|startdate|enddate|
# +---+-----+---+---------+-------+
# |  a|world|  1|     2001|   2002|
# |  a|world|  1|     2005|   2005|
# |  b|world|  2|     2001|   2005|
# |  c|world|  3|     2000|   2001|
# |  c|world|  3|     2004|   2005|
# +---+-----+---+---------+-------+
from pyspark.sql import functions as func

data_ls = [
    ("a", "world", "1", 0, 1, 1, 0, 0, 1),
    ("b", "world", "2", 0, 1, 1, 1, 1, 1),
    ("c", "world", "3", 1, 1, 0, 0, 1, 1)
]

data_sdf = spark.sparkContext.parallelize(data_ls). \
    toDF(['a', 'b', 'id', 'm2000', 'm2001', 'm2002', 'm2003', 'm2004', 'm2005'])
yearcols = [k for k in data_sdf.columns if k.startswith('m20')]
data_sdf. \
    withColumn('yearcol_structs',
               func.array(*[func.struct(func.lit(int(c[-4:])).alias('year'), func.col(c).alias('value'))
                            for c in yearcols])
               ). \
    withColumn('yearcol_1s',
               func.expr('transform(filter(yearcol_structs, x -> x.value = 1), f -> f.year)')
               ). \
    filter(func.size('yearcol_1s') >= 1). \
    withColumn('year_start', func.concat(func.lit('10/10/'), func.array_min('yearcol_1s'))). \
    withColumn('year_end', func.concat(func.lit('10/10/'), func.array_max('yearcol_1s'))). \
    show(truncate=False)
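With the example data this prints a single span per id (output trimmed here to the relevant columns), so separate runs of 1s collapse into one range:
# +---+-----+---+----------+----------+
# |a  |b    |id |year_start|year_end  |
# +---+-----+---+----------+----------+
# |a  |world|1  |10/10/2001|10/10/2005|
# |b  |world|2  |10/10/2001|10/10/2005|
# |c  |world|3  |10/10/2000|10/10/2005|
# +---+-----+---+----------+----------+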
Stack the dataframe to reshape it into long format:
from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Primary id columns
keys = ['a', 'b', 'id']
# Create a dynamic stack expression
stackexpr = f"stack({len(yearcols)}, %s) as (year, val)" \
% ', '.join(f'"{c[-4:]}", {c}' for c in yearcols)
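# for the six year columns above, the expression built here expands to:
# stack(6, "2000", m2000, "2001", m2001, "2002", m2002,
#          "2003", m2003, "2004", m2004, "2005", m2005) as (year, val)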
df = data_sdf.selectExpr(*keys, stackexpr)
df.show()
# +---+-----+---+----+---+
# | a| b| id|year|val|
# +---+-----+---+----+---+
# | a|world| 1|2000| 0|
# | a|world| 1|2001| 1|
# | a|world| 1|2002| 1|
# | a|world| 1|2003| 0|
# | a|world| 1|2004| 0|
# | a|world| 1|2005| 1|
# | b|world| 2|2000| 0|
# | b|world| 2|2001| 1|
# | b|world| 2|2002| 1|
# | b|world| 2|2003| 1|
# | b|world| 2|2004| 1|
# | b|world| 2|2005| 1|
# | c|world| 3|2000| 1|
# | c|world| 3|2001| 1|
# | c|world| 3|2002| 0|
# | c|world| 3|2003| 0|
# | c|world| 3|2004| 1|
# | c|world| 3|2005| 1|
# +---+-----+---+----+---+
Create a supplementary group key to identify blocks of rows having consecutive 1s. Then filter out the rows with zeros, since only the 1s matter for finding the start and end dates:
# flag the zero rows; a running count of zeros gives every block of
# consecutive ones a constant id within its (a, b, id) partition
m = F.col('val') == 0
W = Window.partitionBy(*keys).orderBy('year')
df = df.withColumn('blocks', F.sum(m.cast('int')).over(W)).filter(~m)
df.show()
# +---+-----+---+----+---+------+
# | a| b| id|year|val|blocks|
# +---+-----+---+----+---+------+
# | a|world| 1|2001| 1| 1|
# | a|world| 1|2002| 1| 1|
# | a|world| 1|2005| 1| 3|
# | b|world| 2|2001| 1| 1|
# | b|world| 2|2002| 1| 1|
# | b|world| 2|2003| 1| 1|
# | b|world| 2|2004| 1| 1|
# | b|world| 2|2005| 1| 1|
# | c|world| 3|2000| 1| 0|
# | c|world| 3|2001| 1| 0|
# | c|world| 3|2004| 1| 2|
# | c|world| 3|2005| 1| 2|
# +---+-----+---+----+---+------+
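To see why the running zero count works as a block id, trace it on id 1's values in plain Python (a quick sketch using itertools.accumulate, not part of the Spark pipeline):
from itertools import accumulate

vals = [0, 1, 1, 0, 0, 1]                         # id 1, years 2000..2005
blocks = list(accumulate(int(v == 0) for v in vals))
print(blocks)                                     # [1, 1, 1, 2, 3, 3]
# keeping only the 1s: 2001 and 2002 share block 1, 2005 sits alone in block 3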
Group the dataframe by the keys along with blocks, and aggregate year with min and max:
df = (
df
.groupBy(*keys, 'blocks')
.agg(F.min('year').alias('start'), F.max('year').alias('end'))
.drop('blocks')
)
df.show()
# +---+-----+---+-----+----+
# | a| b| id|start| end|
# +---+-----+---+-----+----+
# | a|world| 1| 2001|2002|
# | a|world| 1| 2005|2005|
# | b|world| 2| 2001|2005|
# | c|world| 3| 2000|2001|
# | c|world| 3| 2004|2005|
# +---+-----+---+-----+----+
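One caveat: groupBy does not guarantee the order of the output rows. If the result must be sorted as shown, add an explicit sort (a minor addition, not part of the answer above):
df = df.orderBy(*keys, 'start')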