I have a relatively shallow directed acyclic graph represented in GraphFrames (a large number of nodes, mostly in disjoint subgraphs). I want to propagate the ids of the root nodes (nodes without incoming edges) to all downstream nodes. To achieve this, I chose the Pregel algorithm. I expected the process to converge once the passed messages stop changing, but it keeps going until the maximum number of iterations is reached.
This is a model of the problem:
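from pyspark.sql import functions as f
from graphframes import GraphFrame
from graphframes.lib import Pregel

# Assumes an active SparkSession named `spark`.
# Each pair is (variantId, explained); a self-pair marks a root node.
data = [
    ('v1', 'v1'),
    ('v3', 'v1'),
    ('v2', 'v1'),
    ('v4', 'v2'),
    ('v4', 'v5'),
    ('v5', 'v5'),
    ('v6', 'v4'),
]
df = spark.createDataFrame(data, ['variantId', 'explained']).persist()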
# Create nodes:
nodes = (
    df.select(
        f.col('variantId').alias('id'),
        f.when(f.col('variantId') == f.col('explained'), f.col('variantId')).alias('origin_root')
    )
    .distinct()
)
# Create edges:
edges = (
    df
    .filter(f.col('variantId') != f.col('explained'))
    .select(
        f.col('variantId').alias('dst'),
        f.col('explained').alias('src'),
        f.lit('explains').alias('edgeType')
    )
    .distinct()
)
# Converting into a graphframe graph:
graph = GraphFrame(nodes, edges)
The graph looks like this (edges point from explained to variantId): v1 -> v2, v1 -> v3, v2 -> v4, v5 -> v4, v4 -> v6.
I want to propagate the root ids as follows:
- [v1] => v2 and v3,
- [v1, v5] => v4 and v6.
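To make the target concrete, here is a minimal plain-Python sketch that computes the expected roots for the toy data by walking up from each node to its ancestors. It is only a reference for the small example, not the GraphFrames approach; the function name `expected_roots` is made up for illustration, and the recursion assumes the graph is acyclic.

```python
from collections import defaultdict

# Same toy data as above: (variantId, explained) pairs.
data = [
    ('v1', 'v1'), ('v3', 'v1'), ('v2', 'v1'), ('v4', 'v2'),
    ('v4', 'v5'), ('v5', 'v5'), ('v6', 'v4'),
]


def expected_roots(pairs):
    """Return {node: sorted list of root ids that can reach it}.

    Hypothetical helper for illustration; assumes an acyclic graph.
    """
    parents = defaultdict(set)
    nodes = set()
    for child, parent in pairs:
        nodes.update((child, parent))
        if child != parent:  # self-pairs only mark roots, they are not edges
            parents[child].add(parent)

    def roots_of(node):
        # A node without incoming edges is its own root.
        if not parents[node]:
            return {node}
        found = set()
        for parent in parents[node]:
            found |= roots_of(parent)
        return found

    return {node: sorted(roots_of(node)) for node in sorted(nodes)}


result = expected_roots(data)
for node, roots in result.items():
    print(node, roots)
```

This is the result I expect the Pregel run to reproduce at scale.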
To do this I wrote the following function:
maxiter = 3
(
    graph.pregel
    .setMaxIter(maxiter)
    # New column for the resolved roots:
    .withVertexColumn(
        "resolved_roots",
        # The value is initialized with the original root value:
        f.when(
            f.col('origin_root').isNotNull(),
            f.array(f.col('origin_root'))
        ).otherwise(f.array()),
        # When a new message arrives at the node, it is merged with the existing list:
        f.when(
            Pregel.msg().isNotNull(),
            f.array_union(Pregel.msg(), f.col('resolved_roots'))
        ).otherwise(f.col("resolved_roots"))
    )
    # Send each node's current list of roots along its outgoing edges (src -> dst):
    .sendMsgToDst(Pregel.src("resolved_roots"))
    # Messages arriving at the same node are concatenated into one list:
    .aggMsgs(f.flatten(f.collect_list(Pregel.msg())))
    .run()
    .orderBy('id')
    .show()
)
It returns:
+---+-----------+--------------+
| id|origin_root|resolved_roots|
+---+-----------+--------------+
| v1| v1| [v1]|
| v2| null| [v1]|
| v3| null| [v1]|
| v4| null| [v1, v5]|
| v5| v5| [v5]|
| v6| null| [v1, v5]|
+---+-----------+--------------+
All the nodes now have their root information, and it no longer changes between iterations. Even so, if we increase the maximum number of iterations to 100, the process just keeps going.
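To show what I mean by "converge", here is a plain-Python sketch of the same message passing that stops as soon as a superstep changes nothing. It only simulates the behaviour I expected on the toy graph and is not GraphFrames code; `propagate_until_stable` and its parameters are hypothetical names of my own.

```python
def propagate_until_stable(edges, initial, max_iter=100):
    """Synchronous message passing that stops once no node changes.

    edges: list of (src, dst) pairs; initial: {node: list of root ids}.
    Returns (final state, number of supersteps executed).
    Hypothetical helper, for illustration only.
    """
    state = {node: set(roots) for node, roots in initial.items()}
    steps = 0
    for steps in range(1, max_iter + 1):
        # Each node sends its current roots to its downstream neighbours.
        inbox = {node: set() for node in state}
        for src, dst in edges:
            inbox[dst] |= state[src]
        # Each node merges what it received into what it already has.
        new_state = {node: state[node] | inbox[node] for node in state}
        if new_state == state:
            break  # nothing changed in this superstep: converged
        state = new_state
    return {node: sorted(roots) for node, roots in state.items()}, steps


# Edges of the toy graph (src -> dst) and the initial root lists.
edges = [('v1', 'v3'), ('v1', 'v2'), ('v2', 'v4'), ('v5', 'v4'), ('v4', 'v6')]
initial = {'v1': ['v1'], 'v2': [], 'v3': [], 'v4': [], 'v5': ['v5'], 'v6': []}

final, steps = propagate_until_stable(edges, initial)
print(final)
print(steps)
```

On the toy graph this loop does three supersteps of real work and uses a fourth only to detect that nothing changed, which is the kind of early stop I was hoping to get from the Pregel run.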
The questions:
- Why won't this process converge?
- What can I do to make sure it converges?
- Is this the right approach to achieve this goal?
Any helpful comment is highly appreciated; I'm completely new to graphs.