
Joining 2 dataframes in pyspark where one column can have duplicates


I have a PySpark dataframe with two columns, ID and CONDITION. ID corresponds to a user, and a user can have multiple conditions. I want to find the users who have both condition A and condition B. How can I do that?

Sample dataframe:

ID CONDITION
1 A
2 B
1 B
1 C
2 C
2 D
1 E

If I want to get users who have A,B as conditions, I need only 1 as the output. If I want to get users who have C,D as conditions, I need only 2 as the output. If I want to get users who have B,C as conditions, I need both 1 and 2 as outputs.

These requirements are represented in a second dataframe, conditions_data, as shown below:

sl_no conditions
s1 [A,B]
s2 [C,D]
s3 [B,C]
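
For reference, the two dataframes could be built with a minimal sketch like the one below (this assumes an active SparkSession named spark; the names df and conditions_data match the ones used in the rest of the post):

    import pyspark.sql.functions as F

    df = spark.createDataFrame(
        [(1, 'A'), (2, 'B'), (1, 'B'), (1, 'C'), (2, 'C'), (2, 'D'), (1, 'E')],
        ['ID', 'CONDITION'])

    conditions_data = spark.createDataFrame(
        [('s1', ['A', 'B']), ('s2', ['C', 'D']), ('s3', ['B', 'C'])],
        ['sl_no', 'conditions'])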

My attempt is as follows:

    df1 = df.groupBy('ID').agg(F.collect_set('CONDITION').alias('conditions'))

    df2 = conditions_data

    result = df1.join(df2, F.array_intersect(df1['conditions'], df2['conditions']) == df2['conditions'])

However, I see some inconsistencies in the results. I also wanted to know if there's a better way to do this.


Solution

  • Collect the unique conditions per ID

    users = df.groupBy('ID').agg(F.collect_set('CONDITION').alias('CONDITION'))
    
    # +---+------------+
    # | ID|   CONDITION|
    # +---+------------+
    # |  1|[C, E, B, A]|
    # |  2|   [C, B, D]|
    # +---+------------+
    

  • Join conditions_data with the users dataframe, using a join condition that checks that conditions is a subset of the user's collected CONDITION set

    cond = F.expr("size(array_intersect(conditions, CONDITION)) = size(conditions)")
    result = conditions_data.join(users, on=cond, how='left')
    
    # +-----+----------+---+------------+
    # |sl_no|conditions| ID|   CONDITION|
    # +-----+----------+---+------------+
    # |   s1|    [A, B]|  1|[C, E, B, A]|
    # |   s2|    [C, D]|  2|   [C, B, D]|
    # |   s3|    [B, C]|  1|[C, E, B, A]|
    # |   s3|    [B, C]|  2|   [C, B, D]|
    # +-----+----------+---+------------+
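
    As a side note, the same subset test can also be written with the forall higher-order function in a SQL expression (available from Spark 3.0 onward); this is just an equivalent alternative to the size/array_intersect condition above:

    cond = F.expr("forall(conditions, x -> array_contains(CONDITION, x))")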
    

  • Collect the users for each unique row in conditions_data

    result = result.groupby(*conditions_data.columns).agg(F.collect_list('ID').alias('ID'))
    
    # +-----+----------+------+
    # |sl_no|conditions|    ID|
    # +-----+----------+------+
    # |   s1|    [A, B]|   [1]|
    # |   s2|    [C, D]|   [2]|
    # |   s3|    [B, C]|[1, 2]|
    # +-----+----------+------+
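
    If you then need the matching users for a single requirement on the driver, one possible (purely illustrative) follow-up is a filter plus first():

    ids_for_s1 = result.filter(F.col('sl_no') == 's1').select('ID').first()['ID']
    # [1]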