python, apache-spark, pyspark, spark-graphx, graphframes

How to Get Connected Components with GraphFrames in PySpark from Raw Data in a Spark DataFrame?


I have a Spark DataFrame that looks like this:

+--+-----+---------+
|id|phone|  address|
+--+-----+---------+
| 0|  123| james st|
| 1|  177|avenue st|
| 2|  123|spring st|
| 3|  999|avenue st|
| 4|  678|  5th ave|
+--+-----+---------+
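
For reference, the sample DataFrame can be recreated with something like the snippet below (it assumes an active SparkSession bound to the name spark):

    df = spark.createDataFrame(
        [(0, 123, 'james st'),
         (1, 177, 'avenue st'),
         (2, 123, 'spring st'),
         (3, 999, 'avenue st'),
         (4, 678, '5th ave')],
        ['id', 'phone', 'address']
    )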

I am trying to use the graphframes package to identify the connected components of ids, using phone and address from the Spark DataFrame above. This DataFrame can therefore be treated as the vertices DataFrame of the graph.

I am wondering what the optimal approach would be for creating the edges DataFrame of the graph to feed into the connectedComponents() function in graphframes.

Ideally, the edges DataFrame should look like this:

+---+---+------------+
|src|dst|relationship|
+---+---+------------+
|  0|  2|  same_phone|
| 1 |  3|same_address|
+---+---+------------+

Finally, the connectedComponents() results should look like the table below. Ids 0 & 2 are in the same component based on the same_phone relationship, and ids 1 & 3 based on the same_address relationship. That leaves 4 as its own component, with no connection to the other ids.

+---+-------------------+
|id |connected_component|
+---+-------------------+
|0  |1                  |
|1  |2                  |
|2  |1                  |
|3  |2                  |
|4  |3                  |
+---+-------------------+

Thanks in advance!


Solution

    from functools import reduce

    # For each non-id column, self-join df on that column to pair up ids
    # sharing a value, keep each pair exactly once (t1.id < t2.id), and
    # union the per-column results into a single edges DataFrame.
    edges = reduce(
        lambda x, y: x.union(y),
        [df.alias('t1')
           .join(df.alias('t2'), c)
           .filter('t1.id < t2.id')
           .selectExpr('t1.id src', 't2.id dst', "'same_%s' relationship" % c)
         for c in df.columns[1:]]
    )
    
    edges.show()
    +---+---+------------+
    |src|dst|relationship|
    +---+---+------------+
    |  0|  2|  same_phone|
    |  1|  3|same_address|
    +---+---+------------+
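
Since the question asks specifically about graphframes, these edges can also be fed straight into it, with df itself serving as the vertices DataFrame. A sketch, assuming the graphframes package is installed: connectedComponents() requires a checkpoint directory (the path below is an arbitrary choice), and it returns arbitrary long component IDs rather than the sequential 1, 2, 3 shown in the question.

    from graphframes import GraphFrame

    # connectedComponents() fails without a checkpoint directory; the path is arbitrary
    spark.sparkContext.setCheckpointDir('/tmp/spark-checkpoints')

    g = GraphFrame(df, edges)              # df provides the required `id` column
    components = g.connectedComponents()   # appends a `component` column
    components.select('id', 'component').show()

Alternatively, for a dataset like this one, the same result can be produced without graphframes at all: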
    
    import pyspark.sql.functions as F
    from pyspark.sql.window import Window

    # Each distinct sorted (src, dst) pair becomes one component, and ids that
    # appear in no edge (anti join) become singleton components. row_number()
    # over the sorted arrays assigns sequential component numbers; explode()
    # turns each array back into one row per id.
    connect = edges.select(
        F.array_sort(F.array('src', 'dst')).alias('arr')
    ).distinct().union(
        df.join(edges, (df.id == edges.src) | (df.id == edges.dst), 'anti').select(F.array('id'))
    ).withColumn(
        'connected_component',
        F.row_number().over(Window.orderBy('arr'))
    ).select(F.explode('arr').alias('id'), 'connected_component')
    
    connect.show()
    +---+-------------------+
    | id|connected_component|
    +---+-------------------+
    |  0|                  1|
    |  2|                  1|
    |  1|                  2|
    |  3|                  2|
    |  4|                  3|
    +---+-------------------+
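
One caveat: this window trick labels each distinct src/dst pair (plus each isolated id) as its own component, which matches the expected output here only because no id participates in more than one edge. If components can chain transitively (say, an extra row with phone 123 joining a new id to both 0 and 2), an id would receive multiple labels, and you would need a genuinely iterative algorithm such as graphframes' connectedComponents() shown above.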