Initially I thought this was an easy problem, but I just can't figure it out. Here is a simplified example: 8 different people each buy some items from a store. Afterwards I want to look at all the items and sort them into groups, so that any two baskets that share an item end up in the same group (and overlapping groups get merged).
The input looks like this:
| person | items |
| ------ | ----------------- |
| 1| [B, A]|
| 2| [C, A]|
| 3| [E, I, D]|
| 4| [F, G]|
| 5| [F]|
| 6| [F, H]|
| 7| [J, A]|
| 8| [K, J]|
and output should look like this:
| bag | items |
| ----| -------------- |
| 1| [A, B, C, J, K]|
| 2| [E, I, D]|
| 3| [F, G, H]|
I tried crossJoin, array_intersect, concatenating arrays, and also doing all of this in several loops. While that eventually produces bag 1 ([A, B, C, J, K]), I cannot easily identify the smaller bags or tell at what point a group is complete. Am I overthinking this?
We can represent this DataFrame as an undirected (and generally disconnected) graph where each item is a node, and we draw an edge between two items whenever they appear in the same person's basket. Items bought by different people then land in the same connected component whenever a chain of shared items links them: for person 1 and person 2 we draw the edges B–A and C–A, so B and C are connected through their common neighbour A even though nobody bought B and C together.
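Conceptually this is just finding connected components. As a sanity check of the idea (not the distributed solution, which follows below), here is a minimal plain-Python union-find sketch over the sample data; the names (`baskets`, `find`, `union`, `groups`) are only illustrative, and it works only because everything fits on the driver:

```python
# Minimal union-find sketch of the same grouping idea (driver-side only).
baskets = {
    "1": ["B", "A"], "2": ["C", "A"], "3": ["E", "I", "D"],
    "4": ["F", "G"], "5": ["F"],      "6": ["F", "H"],
    "7": ["J", "A"], "8": ["K", "J"],
}

parent = {}

def find(x):
    parent.setdefault(x, x)
    while parent[x] != x:
        parent[x] = parent[parent[x]]  # path compression
        x = parent[x]
    return x

def union(a, b):
    parent[find(a)] = find(b)

# Every item in a basket is united with the first item of that basket.
for items in baskets.values():
    for item in items[1:]:
        union(items[0], item)

# Group items by their root representative.
groups = {}
for item in {i for items in baskets.values() for i in items}:
    groups.setdefault(find(item), []).append(item)

for root, items in groups.items():
    print(sorted(items))
# ['A', 'B', 'C', 'J', 'K']
# ['D', 'E', 'I']
# ['F', 'G', 'H']   (order of the groups may vary)
```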
To get all the nodes connected by an edge, we can explode the input DataFrame and self-join it on `person` (the resulting `df_edges` has the fields `src` and `dst` identifying the two nodes of each edge, which is the shape a `GraphFrame` expects). The nodes (`df_nodes`) are all distinct items in the exploded input DataFrame.

From `df_nodes` and `df_edges`, we can create a `GraphFrame` object using the `graphframes` library (which works with PySpark DataFrames) and use its built-in `connectedComponents()` method to find all connected components.
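For instance, person 1 buys [B, A], so after exploding and self-joining on `person` that basket alone contributes the directed pairs (B, A) and (A, B). Once `df_edges` exists (it is built in the full code below), this can be verified with a small filter; this is just a hypothetical sanity check, not part of the solution:

```python
# Hypothetical sanity check: the only edges whose endpoints are both in {A, B}
# come from person 1's basket [B, A]; the self-join emits both directions.
df_edges.filter(
    F.col("src").isin("A", "B") & F.col("dst").isin("A", "B")
).show()
# +---+---+
# |src|dst|
# +---+---+
# |  B|  A|
# |  A|  B|
# +---+---+
# (row order may vary)
```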
Below is the reproducible code:
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, ArrayType
import pyspark.sql.functions as F
from graphframes import GraphFrame
# Initializing Spark session
spark = SparkSession.builder.appName("TestGroup").getOrCreate()
## connectedComponents() requires a checkpoint directory to be set
sc = spark.sparkContext
sc.setCheckpointDir('test')
df = spark.createDataFrame(
    data=[
        ("1", ["B","A"]),
        ("2", ["C","A"]),
        ("3", ["E","I","D"]),
        ("4", ["F","G"]),
        ("5", ["F"]),
        ("6", ["F","H"]),
        ("7", ["J","A"]),
        ("8", ["K","J"]),
    ],
    schema=StructType(
        [
            StructField("person", StringType()),
            StructField("items", ArrayType(StringType()), False),
        ]
    ),
)
## create the edges and vertices DataFrames
df_exploded = df.select("person", F.explode(F.col("items")).alias("item"))

# self-join on person: every pair of distinct items bought by the same person
# becomes an edge (src, dst)
df_edges = df_exploded.alias("df1").join(
    df_exploded.alias("df2"),
    F.col("df1.person") == F.col("df2.person")
).select(
    F.col("df1.item").alias("src"),
    F.col("df2.item").alias("dst")
).filter(F.col("src") != F.col("dst")).distinct()

# the vertices are simply the distinct items
df_nodes = df_exploded.select(F.col("item").alias("id")).distinct()

# Create the GraphFrame and find the connected components
g = GraphFrame(df_nodes, df_edges)
connected_items = g.connectedComponents()

output = connected_items.groupby('component').agg(F.collect_list('id').alias('items'))
output.show()
This is the output:
+---------+---------------+
|component| items|
+---------+---------------+
| 0|[B, A, C, J, K]|
| 5| [F, G, H]|
| 3| [E, D, I]|
+---------+---------------+
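The component IDs (0, 5, 3) and the item order inside each array are arbitrary. If you want output that matches the bag numbering in the question, one possible post-processing step (a sketch only; `bags` is an illustrative name) is to sort each array with `array_sort` and assign sequential IDs with `row_number`:

```python
from pyspark.sql.window import Window

# Optional post-processing sketch: sort the items inside each group and
# renumber the groups 1..N, ordering by the first item of the sorted array.
# The un-partitioned window pulls all rows onto one partition, which is
# fine here because there are only a handful of groups.
bags = (
    output
    .withColumn("items", F.array_sort("items"))
    .withColumn("bag", F.row_number().over(Window.orderBy(F.col("items")[0])))
    .select("bag", "items")
)
bags.show(truncate=False)
# +---+---------------+
# |bag|items          |
# +---+---------------+
# |1  |[A, B, C, J, K]|
# |2  |[D, E, I]      |
# |3  |[F, G, H]      |
# +---+---------------+
```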