Tags: pyspark, subset, permutation

How can I cross subsets of a PySpark dataframe with two columns of another dataframe?


Problem

How can I fill values in a subset of a PySpark dataframe based on another dataframe?

In Table A the subset is defined by a code in the column code; for each code we have multiple registers, each with a value in the column servico. Table B has the columns cod_1 and cod_2 plus the column rule that we want to cross with Table A. Whenever both cod_1 and cod_2 appear among the servico values of the subset defined by code in Table A, we fill Table A's rule column for the register whose servico equals cod_2 with the corresponding value from Table B's rule column.

Requirements

The only restriction is that everything needs to be done in PySpark or SQL, without collect(). I want to maximize performance on a Databricks cluster.

Example

Table A is defined by:

table_a = spark.createDataFrame([
        ('123', 'A1',''), 
        ('123', 'E1',''),
        ('123', 'A3',''),
        ('123', 'B1',''),
        ('123', 'B2',''),
        ('123', 'B3',''),
        ('321', 'C1',''),
        ('321', 'C2',''),
        ('321', 'C3',''),
        ('321', 'C4',''),
        ('321', 'D1',''),
    ],
    ['code', 'servico', 'rule'] 
)

Table B is defined by:

table_b = spark.createDataFrame([
        ('E1', 'A1','aaa'), 
        ('E2', 'A2','bbb'),
        ('F1', 'A3','ccc'),
        ('F2', 'B1','ddd'),
        ('F3', 'B2','eee'),
        ('G1', 'B3','fff'),
        ('G2', 'C1','ggg'),
        ('G3', 'C2','hhh'),
        ('G4', 'C3','iii'),
        ('H1', 'C4','jjj'),
        ('H2', 'D1','kkk'),
    ],
    ['cod_1', 'cod_2', 'rule'] 
)

The expected result is:

result = spark.createDataFrame([
        ('123', 'A1','aaa'), 
        ('123', 'E1',''),
        ('123', 'A3',''),
        ('123', 'B1',''),
        ('123', 'B2',''),
        ('123', 'B3',''),
        ('321', 'C1',''),
        ('321', 'C2',''),
        ('321', 'C3',''),
        ('321', 'C4',''),
        ('321', 'D1',''),
    ],
    ['code', 'servico', 'rule'] 
)

The only register filled from Table B is the first one, because for code 123 we find both A1 and E1 (the cod_2 and cod_1 of the first row of Table B) among the servico values.


Solution

  • In order to prevent a bulky intermediate dataframe (i.e. table_a joined with table_a to get all the combinations) and to handle the key ordering, I collect all the servico values of each code first:

    from pyspark.sql import functions as func, types, Window

    table_a_agg = table_a\
        .withColumn('set_of_servico', func.collect_set('servico').over(Window.partitionBy('code')).cast(types.StringType()))
    table_a_agg.show(100, False)
    +----+-------+----+------------------------+
    |code|servico|rule|set_of_servico          |
    +----+-------+----+------------------------+
    |123 |A1     |    |[B1, A3, A1, B2, E1, B3]|
    |123 |E1     |    |[B1, A3, A1, B2, E1, B3]|
    |123 |A3     |    |[B1, A3, A1, B2, E1, B3]|
    |123 |B1     |    |[B1, A3, A1, B2, E1, B3]|
    |123 |B2     |    |[B1, A3, A1, B2, E1, B3]|
    |123 |B3     |    |[B1, A3, A1, B2, E1, B3]|
    |321 |C1     |    |[C4, C3, D1, C1, C2]    |
    |321 |C2     |    |[C4, C3, D1, C1, C2]    |
    |321 |C3     |    |[C4, C3, D1, C1, C2]    |
    |321 |C4     |    |[C4, C3, D1, C1, C2]    |
    |321 |D1     |    |[C4, C3, D1, C1, C2]    |
    +----+-------+----+------------------------+
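
    If you prefer to avoid a window function here, the same set_of_servico column can also be built with a groupBy followed by a join back to table_a; which of the two performs better depends on the data, so take this only as a sketch of an equivalent formulation:

    # Sketch of an equivalent aggregation: one collect_set per code, joined back.
    servico_sets = table_a\
        .groupBy('code')\
        .agg(func.collect_set('servico').cast(types.StringType()).alias('set_of_servico'))

    table_a_agg_alt = table_a.join(servico_sets, on='code', how='left')
    table_a_agg_alt.show(100, False)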
    

    The reason set_of_servico is cast to string type is the join: with a string we can simply check whether both cod_1 and cod_2 are contained in set_of_servico (a variant that keeps the array type is sketched after the join output below):

    result = table_a_agg.select('code', 'servico', 'set_of_servico').alias('a')\
        .join(table_b.alias('b'),
              [func.col('a.set_of_servico').contains(func.col('b.cod_1')), func.col('a.set_of_servico').contains(func.col('b.cod_2'))],
              how='left')
    result.show(100, False)
    +----+-------+------------------------+-----+-----+----+
    |code|servico|set_of_servico          |cod_1|cod_2|rule|
    +----+-------+------------------------+-----+-----+----+
    |123 |A1     |[B1, A3, A1, B2, E1, B3]|E1   |A1   |aaa |
    |123 |E1     |[B1, A3, A1, B2, E1, B3]|E1   |A1   |aaa |
    |123 |A3     |[B1, A3, A1, B2, E1, B3]|E1   |A1   |aaa |
    |123 |B1     |[B1, A3, A1, B2, E1, B3]|E1   |A1   |aaa |
    |123 |B2     |[B1, A3, A1, B2, E1, B3]|E1   |A1   |aaa |
    |123 |B3     |[B1, A3, A1, B2, E1, B3]|E1   |A1   |aaa |
    |321 |C1     |[C4, C3, D1, C1, C2]    |null |null |null|
    |321 |C2     |[C4, C3, D1, C1, C2]    |null |null |null|
    |321 |C3     |[C4, C3, D1, C1, C2]    |null |null |null|
    |321 |C4     |[C4, C3, D1, C1, C2]    |null |null |null|
    |321 |D1     |[C4, C3, D1, C1, C2]    |null |null |null|
    +----+-------+------------------------+-----+-----+----+
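
    As mentioned above, if you would rather not cast the set to a string, a variant of the same join can keep set_of_servico as an array and use array_contains; this also avoids any chance of one code matching as a substring of another inside the string representation. This is just a sketch of that idea, and the remaining steps stay the same:

    # Hypothetical variant: keep set_of_servico as an array (no string cast)
    # and join with array_contains instead of a substring check.
    table_a_arr = table_a\
        .withColumn('set_of_servico', func.collect_set('servico').over(Window.partitionBy('code')))

    result_arr = table_a_arr.alias('a')\
        .join(table_b.alias('b'),
              [func.expr('array_contains(a.set_of_servico, b.cod_1)'),
               func.expr('array_contains(a.set_of_servico, b.cod_2)')],
              how='left')
    result_arr.show(100, False)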
    

    Then we can check whether servico equals cod_1 or cod_2, so that rows which received a join result but are not themselves part of the matched pair end up with a null rule:

    result = result\
        .select(
            'code',
            'servico',
            func.when((func.col('servico')==func.col('cod_1'))|(func.col('servico')==func.col('cod_2')), func.col('rule')).otherwise(func.lit(None)).alias('rule')
        )
    result.show(100, False)
    +----+-------+----+
    |code|servico|rule|
    +----+-------+----+
    |123 |A1     |aaa |
    |123 |E1     |aaa |
    |123 |A3     |null|
    |123 |B1     |null|
    |123 |B2     |null|
    |123 |B3     |null|
    |321 |C1     |null|
    |321 |C2     |null|
    |321 |C3     |null|
    |321 |C4     |null|
    |321 |D1     |null|
    +----+-------+----+
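
    If you need the output to match the expected result from the question exactly (only the register whose servico equals cod_2 is filled, and empty strings instead of nulls), you could keep only the cod_2 comparison and replace the nulls. A sketch of that tweak, keeping the joined dataframe under its own (hypothetical) name joined_ab first:

    # Hypothetical tweak to reproduce the question's expected result exactly:
    # fill only the register whose servico equals cod_2, and use '' instead of null.
    joined_ab = table_a_agg.select('code', 'servico', 'set_of_servico').alias('a')\
        .join(table_b.alias('b'),
              [func.col('a.set_of_servico').contains(func.col('b.cod_1')), func.col('a.set_of_servico').contains(func.col('b.cod_2'))],
              how='left')

    result_exact = joined_ab.select(
            'code',
            'servico',
            func.when(func.col('servico') == func.col('cod_2'), func.col('rule')).alias('rule')
        )\
        .fillna('', subset=['rule'])
    result_exact.show(100, False)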