
I have a relatively shallow, directed, acyclic graph represented in GraphFrames (a large number of nodes, mostly in disjoint subgraphs). I want to propagate the id of the root nodes (nodes without incoming edges) to all downstream nodes. To achieve this, I chose the Pregel algorithm. The process should converge once the passed messages stop changing; however, it keeps going until the maximum number of iterations is reached.

This is a model of the problem:

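from pyspark.sql import SparkSession, functions as f
from graphframes import GraphFrame
from graphframes.lib import Pregel

spark = SparkSession.builder.getOrCreate()

data = [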
    ('v1', 'v1'),
    ('v3', 'v1'),
    ('v2', 'v1'),
    ('v4', 'v2'),
    ('v4', 'v5'),
    ('v5', 'v5'),
    ('v6', 'v4'),
]

df = spark.createDataFrame(data, ['variantId', 'explained']).persist()

# Create nodes:
nodes = (
    df.select(
        f.col('variantId').alias('id'),
        f.when(f.col('variantId') == f.col('explained'), f.col('variantId')).alias('origin_root')
    )
    .distinct()
)

# Create edges:
edges = (
    df
    .filter(f.col('variantId')!=f.col('explained'))
    .select(
        f.col('variantId').alias('dst'),
        f.col('explained').alias('src'),
        f.lit('explains').alias('edgeType')
    )
    .distinct()
)

# Converting into a graphframe graph:
graph = GraphFrame(nodes, edges)

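The graph looks like this, with each edge pointing from the explaining node (src) to the explained node (dst): v1 -> v2, v1 -> v3, v2 -> v4, v5 -> v4 and v4 -> v6.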

I want to propagate:

  • [v1] => v2 and v3,
  • [v1, v5] => v4 and v6.
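To make the target result concrete without Spark, here is a minimal plain-Python sketch, assuming the same `data` pairs and the same edge direction as above (explaining node -> explained node). It walks the DAG in topological order (Kahn's algorithm), so every node receives the union of its parents' roots. The function name `expected_roots` is hypothetical and not part of GraphFrames.

```python
from collections import defaultdict, deque

data = [
    ('v1', 'v1'),
    ('v3', 'v1'),
    ('v2', 'v1'),
    ('v4', 'v2'),
    ('v4', 'v5'),
    ('v5', 'v5'),
    ('v6', 'v4'),
]


def expected_roots(pairs):
    """Return {node: sorted list of root ids upstream of (or equal to) it}."""
    nodes = set()
    children = defaultdict(list)  # src (explaining) -> list of dst (explained)
    indegree = defaultdict(int)
    for variant, explained in pairs:
        nodes.update((variant, explained))
        if variant != explained:  # self-pairs only mark roots, they are not edges
            children[explained].append(variant)
            indegree[variant] += 1

    # Roots are nodes without incoming edges; each starts with its own id.
    roots = sorted(n for n in nodes if indegree[n] == 0)
    resolved = {n: set() for n in nodes}
    for r in roots:
        resolved[r].add(r)

    # Kahn's algorithm: a node is processed only after all its parents are done.
    remaining = dict(indegree)
    queue = deque(roots)
    while queue:
        node = queue.popleft()
        for child in children[node]:
            resolved[child] |= resolved[node]
            remaining[child] -= 1
            if remaining[child] == 0:
                queue.append(child)
    return {n: sorted(r) for n, r in resolved.items()}


for node, node_roots in sorted(expected_roots(data).items()):
    print(node, node_roots)
```

For this toy graph it prints `['v1']` for v1, v2 and v3, `['v5']` for v5, and `['v1', 'v5']` for v4 and v6, which is the propagation listed above. Because a DAG can be processed in a single topological pass, this version needs no convergence check at all.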

To do this, I wrote the following Pregel job:

maxiter = 3
(
    graph.pregel
    .setMaxIter(maxiter) 
    # New column for the resolved roots:
    .withVertexColumn(
        "resolved_roots", 
        # The value is initialized by the original root value:
        f.when(
            f.col('origin_root').isNotNull(), 
            f.array(f.col('origin_root'))
        ).otherwise(f.array()),
        # When a new value arrives at the node, it is merged with the existing list:
        f.when(
            Pregel.msg().isNotNull(), 
            f.array_union(Pregel.msg(), f.col('resolved_roots'))
        ).otherwise(f.col("resolved_roots"))
    )
    # Send the source's resolved roots along each edge to its destination:
    .sendMsgToDst(Pregel.src("resolved_roots"))
    # Messages arriving at the same node are concatenated into a single list:
    .aggMsgs(f.flatten(f.collect_list(Pregel.msg())))
    .run()
    .orderBy('id')
    .show()
)

It returns:

+---+-----------+--------------+
| id|origin_root|resolved_roots|
+---+-----------+--------------+
| v1|         v1|          [v1]|
| v2|       null|          [v1]|
| v3|       null|          [v1]|
| v4|       null|      [v1, v5]|
| v5|         v5|          [v5]|
| v6|       null|      [v1, v5]|
+---+-----------+--------------+

All nodes now have their root information, and it no longer changes between iterations. Yet if I increase the maximum number of iterations to 100, the process just keeps going.
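The convergence question can be illustrated with a plain-Python model of the bulk-synchronous loop. This is only a sketch of the idea, not GraphFrames' implementation, and it makes no claim about how `Pregel.run()` itself decides to stop. Every superstep, each edge forwards the source's current roots to the destination (as `sendMsgToDst` does), and each node unions the incoming roots into its own. Since every edge sends a message in every superstep, messages never run out, so the only stopping rule left is `max_iter`. The hypothetical flag `stop_when_stable` adds an explicit check: stop as soon as a superstep changes no node's state.

```python
from collections import defaultdict

# Same edges as the GraphFrame above: (src = explaining node, dst = explained node).
edges = [('v1', 'v2'), ('v1', 'v3'), ('v2', 'v4'), ('v5', 'v4'), ('v4', 'v6')]
initial = {'v1': {'v1'}, 'v5': {'v5'},
           'v2': set(), 'v3': set(), 'v4': set(), 'v6': set()}


def pregel_sketch(edges, state, max_iter, stop_when_stable):
    """Propagate root sets; return (final state, number of supersteps run)."""
    state = {v: set(r) for v, r in state.items()}  # do not mutate the input
    for step in range(1, max_iter + 1):
        # Every edge sends the source's current roots, changed or not.
        inbox = defaultdict(set)
        for src, dst in edges:
            inbox[dst] |= state[src]
        # Vertex update: union the aggregated message into the existing roots.
        new_state = {v: state[v] | inbox[v] for v in state}
        changed = new_state != state
        state = new_state
        if stop_when_stable and not changed:
            return state, step
    return state, max_iter


_, steps = pregel_sketch(edges, initial, max_iter=100, stop_when_stable=False)
print(steps)  # 100: nothing but max_iter ends the loop

final, steps = pregel_sketch(edges, initial, max_iter=100, stop_when_stable=True)
print(steps)  # 4: three supersteps change something, the fourth changes nothing
print({v: sorted(r) for v, r in sorted(final.items())})
```

In this model three supersteps are enough because the longest path (v1 -> v2 -> v4 -> v6) has three edges, which matches the complete result obtained with `maxiter = 3`.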

The questions:

  • Why doesn't this process converge?
  • What can I do to make sure it converges?
  • Is this the right approach to achieve this goal?

Any helpful comment is highly appreciated; I'm completely new to graphs.

SDani