Tags: python, pandas, dataframe, apache-spark, pyspark

Create a date range if a column value matches one


I am using an answer I found for how to iterate over select columns, check if a specific value is in these select columns, and use the column names that have that value to create a new table.

That answer uses native PySpark functions to create an array of the column names that have the value 1. The array can then be used to get the min and max of the years, but I want to start a new row whenever a 1 comes after a 0.

Here's an example input table:



# +---+-----+---+-----+-----+-----+-----+-----+-----+
# |  a|    b| id|m2000|m2001|m2002|m2003|m2004|m2005|
# +---+-----+---+-----+-----+-----+-----+-----+-----+
# |  a|world|  1|    0|    1|    1|    0|    0|    1|
# |  b|world|  2|    0|    1|    1|    1|    1|    1|
# |  c|world|  3|    1|    1|    0|    0|    1|    1|
# +---+-----+---+-----+-----+-----+-----+-----+-----+

I want the final table to look like this:

# +---+-----+---+---------+-------+
# |  a|    b| id|startdate|enddate|
# +---+-----+---+---------+-------+
# |  a|world|  1|     2001|   2002|
# |  a|world|  1|     2005|   2005|
# |  b|world|  2|     2001|   2005|
# |  c|world|  3|     2000|   2001|
# |  c|world|  3|     2004|   2005|
# +---+-----+---+---------+-------+

Here is my attempt so far:
from pyspark.sql import SparkSession, functions as func

spark = SparkSession.builder.getOrCreate()

# sample data matching the example input table above
data_ls = [
    ("a", "world", "1", 0, 1, 1, 0, 0, 1),
    ("b", "world", "2", 0, 1, 1, 1, 1, 1),
    ("c", "world", "3", 1, 1, 0, 0, 1, 1)
]

data_sdf = spark.sparkContext.parallelize(data_ls). \
    toDF(['a', 'b', 'id', 'm2000', 'm2001', 'm2002', 'm2003', 'm2004', 'm2005'])


yearcols = [k for k in data_sdf.columns if k.startswith('m20')]

data_sdf. \
    withColumn('yearcol_structs', 
               func.array(*[func.struct(func.lit(int(c[-4:])).alias('year'), func.col(c).alias('value')) 
                            for c in yearcols]
                          )
               ). \
    withColumn('yearcol_1s', 
               func.expr('transform(filter(yearcol_structs, x -> x.value = 1), f -> f.year)')
               ). \
    filter(func.size('yearcol_1s') >= 1). \
    withColumn('year_start', func.concat(func.lit('10/10/'), func.array_min('yearcol_1s'))). \
    withColumn('year_end', func.concat(func.lit('10/10/'), func.array_max('yearcol_1s'))). \
    show(truncate=False)





Solution

  • Step-by-step solution

    Stack the dataframe to reshape it into long format

    # Primary id columns
    keys = ['a', 'b', 'id']
    
    # Create a dynamic stack expression, e.g.
    # stack(6, "2000", m2000, "2001", m2001, ..., "2005", m2005) as (year, val)
    pairs = ', '.join(f'"{c[-4:]}", {c}' for c in yearcols)
    stackexpr = f"stack({len(yearcols)}, {pairs}) as (year, val)"
    
    df = data_sdf.selectExpr(*keys, stackexpr)
    df.show()
    
    # +---+-----+---+----+---+
    # |  a|    b| id|year|val|
    # +---+-----+---+----+---+
    # |  a|world|  1|2000|  0|
    # |  a|world|  1|2001|  1|
    # |  a|world|  1|2002|  1|
    # |  a|world|  1|2003|  0|
    # |  a|world|  1|2004|  0|
    # |  a|world|  1|2005|  1|
    # |  b|world|  2|2000|  0|
    # |  b|world|  2|2001|  1|
    # |  b|world|  2|2002|  1|
    # |  b|world|  2|2003|  1|
    # |  b|world|  2|2004|  1|
    # |  b|world|  2|2005|  1|
    # |  c|world|  3|2000|  1|
    # |  c|world|  3|2001|  1|
    # |  c|world|  3|2002|  0|
    # |  c|world|  3|2003|  0|
    # |  c|world|  3|2004|  1|
    # |  c|world|  3|2005|  1|
    # +---+-----+---+----+---+
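
    As a side note, the same long format can be built without stack by exploding an array of (year, val) structs, similar to what the question's attempt already constructs. A minimal sketch, reusing yearcols, keys and the func alias from the code above:

    # Alternative sketch: explode an array of structs instead of using stack;
    # produces the same a, b, id, year, val long format
    long_sdf = (
        data_sdf
        .withColumn('yv', func.array(*[func.struct(func.lit(c[-4:]).alias('year'),
                                                   func.col(c).alias('val'))
                                       for c in yearcols]))
        .select(*keys, func.explode('yv').alias('yv'))
        .select(*keys, 'yv.year', 'yv.val')
    )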
    

    Create a supplementary group key to identify blocks of rows with consecutive ones. Then filter out the rows with zeros, since only the ones are needed to find the start and end dates

    from pyspark.sql import functions as F, Window

    # cumulative count of zeros gives each run of consecutive ones its own block id
    m = F.col('val') == 0
    W = Window.partitionBy(*keys).orderBy('year')
    df = df.withColumn('blocks', F.sum(m.cast('int')).over(W)).filter(~m)
    df.show()
    
    # +---+-----+---+----+---+------+
    # |  a|    b| id|year|val|blocks|
    # +---+-----+---+----+---+------+
    # |  a|world|  1|2001|  1|     1|
    # |  a|world|  1|2002|  1|     1|
    # |  a|world|  1|2005|  1|     3|
    # |  b|world|  2|2001|  1|     1|
    # |  b|world|  2|2002|  1|     1|
    # |  b|world|  2|2003|  1|     1|
    # |  b|world|  2|2004|  1|     1|
    # |  b|world|  2|2005|  1|     1|
    # |  c|world|  3|2000|  1|     0|
    # |  c|world|  3|2001|  1|     0|
    # |  c|world|  3|2004|  1|     2|
    # |  c|world|  3|2005|  1|     2|
    # +---+-----+---+----+---+------+
    

    Group the dataframe by keys along with blocks and aggregate year with min and max

    df = (
        df
        .groupBy(*keys, 'blocks')
        .agg(F.min('year').alias('start'), F.max('year').alias('end'))
        .drop('blocks')
    )
    
    df.show()
    
    # +---+-----+---+-----+----+
    # |  a|    b| id|start| end|
    # +---+-----+---+-----+----+
    # |  a|world|  1| 2001|2002|
    # |  a|world|  1| 2005|2005|
    # |  b|world|  2| 2001|2005|
    # |  c|world|  3| 2000|2001|
    # |  c|world|  3| 2004|2005|
    # +---+-----+---+-----+----+
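
    If you also want the startdate and enddate column names from the expected output, and the same '10/10/' prefix that the original attempt concatenates onto the years, a small follow-up along these lines should work (a sketch; the fixed day/month and the result name are only illustrative):

    # Optional: rename to startdate/enddate and prepend the '10/10/' prefix
    # used in the question's attempt
    result = (
        df
        .withColumn('startdate', F.concat(F.lit('10/10/'), F.col('start')))
        .withColumn('enddate', F.concat(F.lit('10/10/'), F.col('end')))
        .drop('start', 'end')
    )
    result.show()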