
Multiplying each group's rows for each unique variant in a column and filling those rows with that value


I have a PySpark DataFrame that looks like this:

df = spark.createDataFrame(
    [(0, 'foo', '2020-01-01', '2020-02-01'),
     (0, 'bar', '2020-02-01', '2020-03-01'),
     (0, 'foo', '2020-03-01', '2020-04-01'),
     (0, None, '2020-04-01', '2020-05-01'),
     (1, 'bar', '2020-01-01', '2020-02-01'),
     (1, 'foo', '2020-02-01', '2020-03-01'),
     (2, None, '2020-02-01', '2020-03-01'),
     (2, None, '2020-04-01', '2020-07-01')
     ],
    ['group', 'value', 'start', 'end'])

df.show()
group value start      end 
0     foo   2020-01-01 2020-02-01   
0     bar   2020-02-01 2020-03-01  
0     foo   2020-03-01 2020-04-01
0     None  2020-04-01 2020-05-01  
1     bar   2020-01-01 2020-02-01  
1     foo   2020-02-01 2020-03-01
2     None  2020-02-01 2020-03-01
2     None  2020-04-01 2020-07-01  

I would like to add rows for each variant of the value column within each group (as given by the group column), and then fill each added row with that variant. As @samkart mentioned, since there are 4 rows with group 0, there should be 4 foo and 4 bar values within group 0. None values should not be counted as additional variants, but groups that contain only None values should keep None as their value, so that the result looks like this:

group value start      end  
0     foo   2020-01-01 2020-02-01  
0     foo   2020-02-01 2020-03-01     
0     foo   2020-03-01 2020-04-01 
0     foo   2020-04-01 2020-05-01
0     bar   2020-01-01 2020-02-01   
0     bar   2020-02-01 2020-03-01  
0     bar   2020-03-01 2020-04-01
0     bar   2020-04-01 2020-05-01
1     bar   2020-01-01 2020-02-01  
1     bar   2020-02-01 2020-03-01
1     foo   2020-01-01 2020-02-01  
1     foo   2020-02-01 2020-03-01
2     None  2020-02-01 2020-03-01
2     None  2020-04-01 2020-07-01

I experimented with counting the variants and then exploding the rows with

df = df.withColumn("n", func.expr("explode(array_repeat(n, int(n)))"))

but I can't figure out a way to fill in the variant values in the desired way.


Solution

  • You're close. Here's a working example using your input data.
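
    The snippets below refer to func and wd without showing the imports; presumably these are the usual aliases for pyspark.sql.functions and the Window class, with data_sdf being the question's df. A minimal setup under those assumptions:

    # assumed setup (not shown in the original answer)
    from pyspark.sql import functions as func
    from pyspark.sql.window import Window as wd

    data_sdf = df   # the DataFrame from the question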

    data_sdf. \
        withColumn('group_count', 
                   func.count('group').over(wd.partitionBy('group')).cast('int')
                   ). \
        filter(func.col('value').isNotNull()). \
        dropDuplicates(['group', 'value']). \
        withColumn('new_val_arr', func.expr('array_repeat(value, group_count)')). \
        selectExpr('group', 'explode(new_val_arr) as value'). \
        show()
    
    # +-----+-----+
    # |group|value|
    # +-----+-----+
    # |    0|  foo|
    # |    0|  foo|
    # |    0|  foo|
    # |    0|  foo|
    # |    0|  bar|
    # |    0|  bar|
    # |    0|  bar|
    # |    0|  bar|
    # |    1|  bar|
    # |    1|  bar|
    # |    1|  foo|
    # |    1|  foo|
    # +-----+-----+
    

    EDIT - The question was updated to keep null values as-is for groups where all values are null.

    There are two ways to do this.

    Filter out the nulls, then append records back for the groups where all values are null

    data2_sdf = data_sdf. \
        withColumn('group_count', 
                   func.count('group').over(wd.partitionBy('group')).cast('int')
                   ). \
        withColumn('null_count',
                   func.sum(func.col('value').isNull().cast('int')).over(wd.partitionBy('group'))
                   )
    
    data2_sdf. \
        filter(func.col('group_count') != func.col('null_count')). \
        filter(func.col('value').isNotNull()). \
        dropDuplicates(['group', 'value']). \
        withColumn('new_val_arr', func.expr('array_repeat(value, group_count)')). \
        selectExpr('group', 'explode(new_val_arr) as value'). \
        unionByName(data2_sdf.
                    filter(func.col('group_count') == func.col('null_count')).
                    select('group', 'value')
                    ). \
        show()
    
    # +-----+-----+
    # |group|value|
    # +-----+-----+
    # |    0|  foo|
    # |    0|  foo|
    # |    0|  foo|
    # |    0|  foo|
    # |    0|  bar|
    # |    0|  bar|
    # |    0|  bar|
    # |    0|  bar|
    # |    1|  bar|
    # |    1|  bar|
    # |    1|  foo|
    # |    1|  foo|
    # |    2| null|
    # |    2| null|
    # +-----+-----+
    

    Or, create an array of the unique values per group and explode it

    data_sdf. \
        withColumn('group_count', 
                   func.count('group').over(wd.partitionBy('group')).cast('int')
                   ). \
        withColumn('null_count',
                   func.sum(func.col('value').isNull().cast('int')).over(wd.partitionBy('group'))
                   ). \
        filter(func.col('value').isNotNull() | (func.col('group_count') == func.col('null_count'))). \
        groupBy('group', 'group_count'). \
        agg(func.collect_set(func.coalesce('value', func.lit('null'))).alias('val_set')). \
        withColumn('new_val_arr', func.expr('flatten(array_repeat(val_set, group_count))')). \
        selectExpr('group', 'explode(new_val_arr) as value'). \
        withColumn('value', func.when(func.col('value') != 'null', func.col('value'))). \
        show()
    
    # +-----+-----+
    # |group|value|
    # +-----+-----+
    # |    0|  bar|
    # |    0|  foo|
    # |    0|  bar|
    # |    0|  foo|
    # |    0|  bar|
    # |    0|  foo|
    # |    0|  bar|
    # |    0|  foo|
    # |    1|  bar|
    # |    1|  foo|
    # |    1|  bar|
    # |    1|  foo|
    # |    2| null|
    # |    2| null|
    # +-----+-----+
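
    The snippets above only return group and value, while the question's expected output also keeps start and end. A possible extension (not from the original answer), assuming the same data_sdf and imports as above: collect the distinct non-null variants per group and left-join them back onto every row of that group, falling back to the original value for all-null groups.

    # hypothetical sketch: preserve start/end while repeating each group's rows per variant
    variants = data_sdf. \
        filter(func.col('value').isNotNull()). \
        select('group', 'value'). \
        dropDuplicates(). \
        withColumnRenamed('value', 'variant')

    data_sdf. \
        join(variants, on='group', how='left'). \
        withColumn('value', func.coalesce('variant', 'value')). \
        drop('variant'). \
        orderBy('group', 'value', 'start'). \
        show()

    For group 0 this should yield 8 rows (4 foo, 4 bar) with the original start/end ranges, and group 2 keeps its two null rows.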