Try something like this:
graph.edges
  .filter(_.srcId == x)                                  // out-edges of x
  .map(e => (e.dstId, null))                             // key by each downstream neighbor
  .join(graph.collectNeighborIds(EdgeDirection.Either))  // attach each neighbor's own neighbors
  .flatMap { case (_, (_, neighborIds)) => neighborIds } // flatten into one list of ids
  .collect
  .toSet
If you want to go deeper than this, I would use something like the Pregel API. Essentially, it lets you repeatedly send messages from node to node and aggregate the results.
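For reference, this is the signature of GraphX's pregel operator (defined on GraphOps) -- note that maxIterations and activeDirection both have defaults, which matters below:

def pregel[A: ClassTag](
    initialMsg: A,
    maxIterations: Int = Int.MaxValue,
    activeDirection: EdgeDirection = EdgeDirection.Either)(
    vprog: (VertexId, VD, A) => VD,
    sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],
    mergeMsg: (A, A) => A): Graph[VD, ED]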
Edit: Pregel Solution
I finally got the iterations to stop on their own. Edits below. Given this graph:
graph.vertices.collect
res46: Array[(org.apache.spark.graphx.VertexId, Array[Long])] = Array((4,Array()), (8,Array()), (1,Array()), (9,Array()), (5,Array()), (6,Array()), (2,Array()), (3,Array()), (7,Array()))
graph.edges.collect
res47: Array[org.apache.spark.graphx.Edge[Double]] = Array(Edge(1,2,0.0), Edge(2,3,0.0), Edge(3,4,0.0), Edge(5,6,0.0), Edge(6,7,0.0), Edge(7,8,0.0), Edge(8,9,0.0), Edge(4,2,0.0), Edge(6,9,0.0), Edge(7,9,0.0))
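(The construction code isn't shown above; as a minimal sketch, assuming a spark-shell with a SparkContext named sc, something like this would reproduce the graph:)

import org.apache.spark.graphx.{Edge, Graph}

// Same ten edges as above; Graph.fromEdges gives every vertex the
// default attribute -- here an empty Array[Long], i.e. "knows nobody yet"
val edgeRDD = sc.parallelize(Seq(
  Edge(1L, 2L, 0.0), Edge(2L, 3L, 0.0), Edge(3L, 4L, 0.0),
  Edge(5L, 6L, 0.0), Edge(6L, 7L, 0.0), Edge(7L, 8L, 0.0),
  Edge(8L, 9L, 0.0), Edge(4L, 2L, 0.0), Edge(6L, 9L, 0.0), Edge(7L, 9L, 0.0)
))
val graph: Graph[Array[Long], Double] = Graph.fromEdges(edgeRDD, Array[Long]())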
We are going to send messages of the type Array[Long] -- an array of all the VertexIds of connected nodes. Messages are going to go upstream -- the dst will send the src its VertexId along with all of the other downstream VertexIds. If the upstream node already knows about the connection, no message will be sent. Eventually, every node knows about every connected node and no more messages will be sent.
First we define our vprog. According to the docs:
the user-defined vertex program which runs on each vertex and receives
the inbound message and computes a new vertex value. On the first
iteration the vertex program is invoked on all vertices and is passed
the default message. On subsequent iterations the vertex program is
only invoked on those vertices that receive messages.
def vprog(id: VertexId, orig: Array[Long], newly: Array[Long]): Array[Long] = {
  // Merge what this vertex already knew with the incoming message, dropping duplicates
  (orig ++ newly).toSet.toArray
}
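The toSet.toArray round trip is just deduplication. On the first iteration, newly is the initial message -- the empty Array[Long]() we pass to pregel below -- so every vertex starts out unchanged.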
Then we define our sendMsg -- edited: swapped src & dst. From the docs:
a user supplied function that is applied to out edges of vertices that
received messages in the current iteration
def sendMsg(trip: EdgeTriplet[Array[Long], Double]): Iterator[(VertexId, Array[Long])] = {
  // Everything dst can contribute: its own id plus all the downstream ids it knows
  val fromDst = (Array(trip.dstId) ++ trip.dstAttr).toSet.toArray
  // Only message src if it is still missing at least one of those ids
  if (trip.srcAttr.intersect(fromDst).length != fromDst.length) {
    Iterator((trip.srcId, fromDst))
  } else {
    Iterator.empty
  }
}
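Combined with the quoted behavior of vprog (after the first iteration it only runs on vertices that received a message), this is what makes the run terminate on its own: once every src already knows everything its dst knows, sendMsg returns Iterator.empty everywhere, no messages are in flight, and Pregel stops.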
Next our mergeMsg:
a user supplied function that takes two incoming messages of type A
and merges them into a single message of type A. This function must be
commutative and associative and ideally the size of A should not
increase.
Unfortunately, we're going to break the rule in the last sentence above:
def mergeMsg(a: Array[Long], b: Array[Long]): Array[Long] = {
  // Deduplicated union -- commutative and associative, as required
  (a ++ b).toSet.toArray
}
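The consequence of breaking it is that messages (and vertex attributes) grow with the size of each node's reachable set, so memory scales with how connected the graph is -- no problem at this scale, but worth keeping in mind on large graphs.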
Then we run pregel -- edited: removed maxIterations, which defaults to Int.MaxValue:
val result = graph.pregel(Array[Long]())(vprog, sendMsg, mergeMsg)
And you can look at the results:
result.vertices.collect
res48: Array[(org.apache.spark.graphx.VertexId, Array[Long])] = Array((4,Array(4, 2, 3)), (8,Array(8, 9)), (1,Array(1, 2, 3, 4)), (9,Array(9)), (5,Array(5, 6, 9, 7, 8)), (6,Array(6, 7, 9, 8)), (2,Array(2, 3, 4)), (3,Array(3, 4, 2)), (7,Array(7, 8, 9)))
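From there, answering the original question for one starting vertex is just a lookup -- e.g., taking x = 5L as a hypothetical starting node:

val x = 5L  // hypothetical starting vertex
result.vertices.filter(_._1 == x).first._2
// res: Array(5, 6, 9, 7, 8) -- matching the output above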