
Update values in a column based on the values of another data frame's column in PySpark


I have two data frames in PySpark: df1

+---+-----------------+
|id1|           items1|
+---+-----------------+
|  0|     [B, C, D, E]|
|  1|        [E, A, C]|
|  2|     [F, A, E, B]|
|  3|        [E, G, A]|
|  4|  [A, C, E, B, D]|
+---+-----------------+ 

and df2:

+---+-----------------+
|id2|           items2|
+---+-----------------+
|001|              [B]|
|002|              [A]|
|003|              [C]|
|004|              [E]|
+---+-----------------+ 

I would like to add a new column to df1 that holds a filtered copy of items1, keeping only the values that also appear in items2 in any row of df2. The result should look as follows:

+---+-----------------+----------------------+
|id1|           items1|        items1_updated|
+---+-----------------+----------------------+
|  0|     [B, C, D, E]|             [B, C, E]|
|  1|        [E, A, C]|             [E, A, C]|
|  2|     [F, A, E, B]|             [A, E, B]|
|  3|        [E, G, A]|                [E, A]|
|  4|  [A, C, E, B, D]|          [A, C, E, B]|
+---+-----------------+----------------------+

I would normally use collect() to get a list of all values in the items2 column and then apply a udf to each row of items1 to compute the intersection. But the data is extremely large (over 10 million rows), so I cannot use collect() to build such a list. Is there a way to do this while keeping the data in a data frame, or some other way that avoids collect()?


Solution

  • The first thing you want to do is explode the values in df2.items2 so that the contents of the arrays end up on separate rows:

    from pyspark.sql.functions import explode
    df2 = df2.select(explode("items2").alias("items2"))
    df2.show()
    #+------+
    #|items2|
    #+------+
    #|     B|
    #|     A|
    #|     C|
    #|     E|
    #+------+
    

    (This assumes that the values in df2.items2 are distinct; if they are not, add df2 = df2.distinct() so that each item is matched only once.)

    Option 1: Use crossJoin:

    Now you can crossJoin the new df2 back to df1 and keep only the rows where df1.items1 contains the element in df2.items2. We can do this with pyspark.sql.functions.array_contains, calling it inside expr() so that a column value (r.items2) can be passed as the second parameter.

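    After filtering, group by id1 and items1 and aggregate using pyspark.sql.functions.collect_list. Note that collect_list does not guarantee element order, which is why the items in the output below are reordered.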

    from pyspark.sql.functions import expr, collect_list
    
    df1.alias("l").crossJoin(df2.alias("r"))\
        .where(expr("array_contains(l.items1, r.items2)"))\
        .groupBy("l.id1", "l.items1")\
        .agg(collect_list("r.items2").alias("items1_updated"))\
        .show()
    #+---+---------------+--------------+
    #|id1|         items1|items1_updated|
    #+---+---------------+--------------+
    #|  1|      [E, A, C]|     [A, C, E]|
    #|  0|   [B, C, D, E]|     [B, C, E]|
    #|  4|[A, C, E, B, D]|  [B, A, C, E]|
    #|  3|      [E, G, A]|        [A, E]|
    #|  2|   [F, A, E, B]|     [B, A, E]|
    #+---+---------------+--------------+
    

    Option 2: Explode df1.items1 and left join:

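    Another option is to explode the contents of items1 in df1 and do a left join. After the join, we do a similar group by and aggregation as above. This works because collect_list ignores the null values introduced by the non-matching rows. Unlike Option 1, a df1 row with no matching items is kept, with an empty items1_updated array; items1 itself is rebuilt by collect_list, so its original order may change.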

    df1.withColumn("items1", explode("items1")).alias("l")\
        .join(df2.alias("r"), on=expr("l.items1=r.items2"), how="left")\
        .groupBy("l.id1")\
        .agg(
            collect_list("l.items1").alias("items1"),
            collect_list("r.items2").alias("items1_updated")
        ).show()
    #+---+---------------+--------------+
    #|id1|         items1|items1_updated|
    #+---+---------------+--------------+
    #|  0|   [E, B, D, C]|     [E, B, C]|
    #|  1|      [E, C, A]|     [E, C, A]|
    #|  3|      [E, A, G]|        [E, A]|
    #|  2|   [F, E, B, A]|     [E, B, A]|
    #|  4|[E, B, D, C, A]|  [E, B, C, A]|
    #+---+---------------+--------------+