Tags: pyspark, graph, graphframes

GraphFrames Pregel doesn't converge


I have a relatively shallow, directed, acyclic graph represented in GraphFrames (a large number of nodes, mostly in disjoint subgraphs). I want to propagate the ids of the root nodes (nodes without incoming edges) to all nodes downstream. To achieve this, I chose the Pregel algorithm. The process should converge once the passed messages no longer change; however, it keeps going until the maximum iteration count is reached.

This is a model of the problem:

import pyspark.sql.functions as f
from graphframes import GraphFrame
from graphframes.lib import Pregel

data = [
    ('v1', 'v1'),
    ('v3', 'v1'),
    ('v2', 'v1'),
    ('v4', 'v2'),
    ('v4', 'v5'),
    ('v5', 'v5'),
    ('v6', 'v4'),
]

df = spark.createDataFrame(data, ['variantId', 'explained']).persist()

# Create nodes:
nodes = (
    df.select(
        f.col('variantId').alias('id'),
        f.when(f.col('variantId') == f.col('explained'), f.col('variantId')).alias('origin_root')
    )
    .distinct()
)

# Create edges:
edges = (
    df
    .filter(f.col('variantId')!=f.col('explained'))
    .select(
        f.col('variantId').alias('dst'),
        f.col('explained').alias('src'),
        f.lit('explains').alias('edgeType')
    )
    .distinct()
)

# Converting into a graphframe graph:
graph = GraphFrame(nodes, edges)

The graph will look like this:

(Graph image: roots v1 and v5; v1 → v2, v3; v2, v5 → v4; v4 → v6.)

I want to propagate

  • [v1] => v2 and v3,
  • [v1, v5] => v4 and v6.

To do this I wrote the following Pregel run:

maxiter = 3
(
    graph.pregel
    .setMaxIter(maxiter) 
    # New column for the resolved roots:
    .withVertexColumn(
        "resolved_roots", 
        # The value is initialized by the original root value:
        f.when(
            f.col('origin_root').isNotNull(), 
            f.array(f.col('origin_root'))
        ).otherwise(f.array()),
        # When new value arrives to the node, it gets merged with the existing list:
        f.when(
            Pregel.msg().isNotNull(), 
            f.array_union(Pregel.msg(), f.col('resolved_roots'))
        ).otherwise(f.col("resolved_roots"))
    )
    # Each node sends its current list of resolved roots downstream:
    .sendMsgToDst(Pregel.src("resolved_roots"))
    # Messages arriving at a node are collected and flattened into a single list:
    .aggMsgs(f.flatten(f.collect_list(Pregel.msg())))
    .run()
    .orderBy('id')
    .show()
)

It returns:

+---+-----------+--------------+
| id|origin_root|resolved_roots|
+---+-----------+--------------+
| v1|         v1|          [v1]|
| v2|       null|          [v1]|
| v3|       null|          [v1]|
| v4|       null|      [v1, v5]|
| v5|         v5|          [v5]|
| v6|       null|      [v1, v5]|
+---+-----------+--------------+

All the nodes now have the correct root information, and it no longer changes between iterations. Yet if we increase the max iteration number to 100, the process just keeps running for all 100 iterations.

The questions:

  • Why doesn't this process converge?
  • What can I do to make sure it converges?
  • Is this the right approach to achieve this goal?

Any helpful comment is highly appreciated, I'm absolutely new to graphs.


Solution

  • Open-source GraphFrames does not take the active message count into consideration; it relies solely on the iteration count to decide when to exit.

    https://github.com/graphframes/graphframes/blob/0b4df70038a4f0ff3b4544223089084fc8742da7/src/main/scala/org/graphframes/lib/Pregel.scala#L205

    The code there looks like while (iteration <= maxIter)

    The GraphFrames build pre-installed with Databricks ML runtimes is not open source, as far as I know, but it probably follows the same pattern.

    If you need a proper exit based on active messages, you have to use the Spark GraphX Scala API as of now.

    The GraphX implementation tracks the number of active messages and exits as soon as no new active messages are generated.

    https://github.com/apache/spark/blob/86f15e4a779ec746373c78c189830cb339b07492/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala#L145C12-L145C36

    The code there looks like this: while (isActiveMessagesNonEmpty && i < maxIterations)

    I have a Medium blog post explaining Pregel in Scala, with examples somewhat similar to the problem in this thread.

    https://towardsdatascience.com/spark-graphx-pregel-its-not-so-complex-as-it-sounds-d196da246c73