Logic:
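A plain BFS is normally implemented iteratively with a queue, and here the search must stay iterative too. A recursive (depth-first) traversal would cover all nodes from one side (start or end) before ever switching sides, and would only stop when it finds the end or runs out of nodes. A bidirectional search instead has to alternate between the two sides, one BFS level at a time.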
So in order to do a bidirectional search, the logic will be explained with the example below:
/*
Let's say this is the graph
    2------5------8
   /              |
  /               |
 /                |
1---3------6------9
 \                |
  \               |
   \              |
    4------7------10
We want to find a path between nodes 1 and 9. To do this we need two data structures, one recording paths from the start and the other recording paths from the end:*/
// startTrav: nodes reached from the start node, in BFS order; each entry is a one-key map
// from the reached node's value to the path (list of nodes) used to reach it.
ArrayList<HashMap<Integer, LinkedList<Node<Integer>>>> startTrav = new ArrayList<>();
// endTrav: the same structure for the search running backwards from the end node.
ArrayList<HashMap<Integer, LinkedList<Node<Integer>>>> endTrav = new ArrayList<>();
/*Before starting the loop, initialise these with the values shown below:
startTrav --> index=0 --> <1, {1}>
endTrav --> index=0 --> <9, {9}>
Note that in each HashMap, the key is the node we have reached and the value is a LinkedList containing the path used to reach that node.
Inside the loop we first expand startTrav. We process its entries from index 0 to 0 (the current level), and append every not-yet-visited child of each processed node to startTrav. So startTrav becomes:
startTrav --> index=0 --> <1, {1}>
startTrav --> index=1 --> <2, {1,2}>
startTrav --> index=2 --> <3, {1,3}>
startTrav --> index=3 --> <4, {1,4}>
Now we check for a collision, i.e. whether any node covered in startTrav is also in endTrav (i.e. whether any of 1,2,3,4 is present in endTrav's list = 9). The answer is no, so continue the loop.
Now do the same from endTrav
endTrav --> index=0 --> <9, {9}>
endTrav --> index=1 --> <8, {9,8}>
endTrav --> index=2 --> <6, {9,6}>
endTrav --> index=3 --> <10, {9,10}>
Now again we check for a collision, i.e. whether any node covered in startTrav is also in endTrav (i.e. whether any of 1,2,3,4 is present in endTrav's list = 9,8,6,10). The answer is no, so continue the loop.
// end of 1st iteration of while loop
// beginning of 2nd iteration of while loop
startTrav --> index=0 --> <1, {1}>
startTrav --> index=1 --> <2, {1,2}>
startTrav --> index=2 --> <3, {1,3}>
startTrav --> index=3 --> <4, {1,4}>
startTrav --> index=4 --> <5, {1,2,5}>
startTrav --> index=5 --> <6, {1,3,6}>
startTrav --> index=6 --> <7, {1,4,7}>
Now again we check for a collision, i.e. whether any node covered in startTrav is also in endTrav (i.e. whether any of 1,2,3,4,5,6,7 is present in endTrav's list = 9,8,6,10). The answer is yes: a collision has occurred on node 6. Break the loop now.
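Now pick the path to 6 from startTrav ({1,3,6}) and the path to 6 from endTrav ({9,6}), and merge the two. Drop the duplicate 6 from the end-side path, reverse what remains, and append it, giving {1,3,6,9}.*/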
Code for this is as below:
/** A graph vertex holding a value and the list of its neighbouring nodes. */
class Node<T> {
  public T value;
  public LinkedList<Node<T>> nextNodes = new LinkedList<>();

  /** Returns this node's value (same as reading the public {@code value} field). */
  public T getValue() {
    return value;
  }
}

/** Adjacency-list graph: maps each node's numeric value to its {@link Node}. */
class Graph<T> {
  public HashMap<Integer, Node<T>> graph = new HashMap<>();
}

/**
 * Bidirectional breadth-first search between two nodes of a {@link Graph}.
 *
 * <p>The search alternates between the start side and the end side, expanding one full BFS level
 * at a time, and stops as soon as some node has been reached from both sides. The backward search
 * also follows {@code nextNodes}, so the graph is assumed to be undirected (every edge listed on
 * both of its endpoints).
 */
public class BiDirectionalBFS {

  /**
   * Finds a shortest (fewest-edges) path from {@code startNode} to {@code endNode}.
   *
   * @param graph the graph; every value reachable through {@code nextNodes} is assumed to be a key
   *     of {@code graph.graph}
   * @param startNode value of the first node of the path
   * @param endNode value of the last node of the path
   * @return the nodes on the path, start to end inclusive, or {@code null} if either node is not in
   *     the graph or no path exists
   */
  public LinkedList<Node<Integer>> findPath(Graph<Integer> graph, int startNode, int endNode) {
    if (!graph.graph.containsKey(startNode) || !graph.graph.containsKey(endNode)) {
      return null;
    }
    if (startNode == endNode) {
      LinkedList<Node<Integer>> ll = new LinkedList<>();
      ll.add(graph.graph.get(startNode));
      return ll;
    }
    // BFS order of each side; every entry is a one-key map: reached node value -> path to it.
    ArrayList<HashMap<Integer, LinkedList<Node<Integer>>>> startTrav = new ArrayList<>();
    ArrayList<HashMap<Integer, LinkedList<Node<Integer>>>> endTrav = new ArrayList<>();
    // Visited sets keyed by node value (also giving O(1) path lookup). A boolean[] sized by the
    // node count and indexed by value overflows whenever values are not exactly 0..n-1.
    HashMap<Integer, LinkedList<Node<Integer>>> pathsFromStart = new HashMap<>();
    HashMap<Integer, LinkedList<Node<Integer>>> pathsFromEnd = new HashMap<>();
    addDetailsToAL(graph, startNode, startTrav, pathsFromStart, null);
    addDetailsToAL(graph, endNode, endTrav, pathsFromEnd, null);

    Integer collision = null;
    int startIndex = 0;
    int endIndex = 0;
    while (startTrav.size() > startIndex && endTrav.size() > endIndex) {
      // Expand the whole current level on the start side.
      int nextLevelStart = startTrav.size();
      for (int i = startIndex; i < nextLevelStart; i++) {
        recordAllChild(graph, startTrav, i, pathsFromStart);
      }
      startIndex = nextLevelStart;
      // No collision existed before this level, so only the newly added nodes need checking.
      if ((collision = checkColission(startTrav, nextLevelStart, pathsFromEnd)) != null) {
        break;
      }
      // Expand the whole current level on the end side.
      nextLevelStart = endTrav.size();
      for (int i = endIndex; i < nextLevelStart; i++) {
        recordAllChild(graph, endTrav, i, pathsFromEnd);
      }
      endIndex = nextLevelStart;
      if ((collision = checkColission(endTrav, nextLevelStart, pathsFromStart)) != null) {
        break;
      }
    }
    if (collision == null) {
      return null;
    }
    // Path start -> collision, followed by the end-side path (end -> collision) walked backwards,
    // starting just before the collision node so it is not added twice.
    LinkedList<Node<Integer>> path = new LinkedList<>(pathsFromStart.get(collision));
    LinkedList<Node<Integer>> pathFromEnd = pathsFromEnd.get(collision);
    ListIterator<Node<Integer>> li = pathFromEnd.listIterator(pathFromEnd.size() - 1);
    while (li.hasPrevious()) {
      path.add(li.previous());
    }
    return path;
  }

  /**
   * Appends to {@code listToAdd} every not-yet-visited neighbour of the node recorded at
   * {@code listToAdd.get(index)}, each with its parent's path extended by itself.
   */
  private void recordAllChild(
      Graph<Integer> graph,
      ArrayList<HashMap<Integer, LinkedList<Node<Integer>>>> listToAdd,
      int index,
      HashMap<Integer, LinkedList<Node<Integer>>> visited) {
    HashMap<Integer, LinkedList<Node<Integer>>> record = listToAdd.get(index);
    Integer recordKey = record.keySet().iterator().next();
    LinkedList<Node<Integer>> pathToRecord = record.get(recordKey);
    for (Node<Integer> child : graph.graph.get(recordKey).nextNodes) {
      if (!visited.containsKey(child.value)) {
        addDetailsToAL(graph, child.value, listToAdd, visited, pathToRecord);
      }
    }
  }

  /**
   * Records {@code node} as reached: its path is a copy of {@code oldLLContent} (or empty when
   * {@code null}) plus the node itself, stored both in {@code trav} and in {@code visited}.
   */
  private void addDetailsToAL(
      Graph<Integer> graph,
      Integer node,
      ArrayList<HashMap<Integer, LinkedList<Node<Integer>>>> trav,
      HashMap<Integer, LinkedList<Node<Integer>>> visited,
      LinkedList<Node<Integer>> oldLLContent) {
    LinkedList<Node<Integer>> ll =
        oldLLContent == null ? new LinkedList<>() : new LinkedList<>(oldLLContent);
    ll.add(graph.graph.get(node));
    HashMap<Integer, LinkedList<Node<Integer>>> hm = new HashMap<>();
    hm.put(node, ll);
    trav.add(hm);
    visited.put(node, ll);
  }

  /**
   * Returns the value of the first node in {@code trav} at or after {@code fromIndex} that the
   * other side has already reached, or {@code null} if there is none.
   */
  private Integer checkColission(
      ArrayList<HashMap<Integer, LinkedList<Node<Integer>>>> trav,
      int fromIndex,
      HashMap<Integer, LinkedList<Node<Integer>>> otherVisited) {
    for (int i = fromIndex; i < trav.size(); i++) {
      Integer key = trav.get(i).keySet().iterator().next();
      if (otherVisited.containsKey(key)) {
        return key;
      }
    }
    return null;
  }
}
A neater and easier-to-understand approach uses arrays. We replace the complex data structure:
ArrayList<HashMap<Integer, LinkedList<Node<Integer>>>>
with a simple
LinkedList<Node<Integer>>[]
Here, the array index is the node's numeric value: if a node has value 7, the path used to reach 7 is stored at index 7. This only works when node values are non-negative integers smaller than the array length. We also drop the boolean visited arrays, because a non-null entry in the path array already shows that a node has been reached. We add two
LinkedList<Node<Integer>>
which serve as BFS queues holding the current frontier of each side, as in a level-order traversal of a tree. Lastly, paths discovered from the end are stored in reverse order, so that merging does not require reversing the second path. The code is below:
/** A graph vertex holding a value and the list of its neighbouring nodes. */
class Node<T> {
  public T value;
  public LinkedList<Node<T>> nextNodes = new LinkedList<>();
}

/** Adjacency-list graph: maps each node's numeric value to its {@link Node}. */
class Graph<T> {
  public HashMap<Integer, Node<T>> graph = new HashMap<>();
}

/**
 * Bidirectional BFS that stores discovered paths in arrays indexed by node value.
 *
 * <p>Node values must be non-negative integers; the path arrays are sized to the largest value
 * found in the graph plus one. The backward search also follows {@code nextNodes}, so the graph is
 * assumed to be undirected (every edge listed on both of its endpoints).
 */
public class BiDirectionalBFS {

  /**
   * Finds a shortest (fewest-edges) path from {@code startNode} to {@code endNode}.
   *
   * @param graph the graph; every node reachable through {@code nextNodes} is assumed to be a key
   *     of {@code graph.graph}
   * @param startNode value of the first node of the path
   * @param endNode value of the last node of the path
   * @return the nodes on the path, start to end inclusive, or {@code null} if either node is not in
   *     the graph or no path exists
   * @throws IllegalArgumentException if the graph contains a negative node value
   */
  public LinkedList<Node<Integer>> findPathUsingArrays(
      Graph<Integer> graph, int startNode, int endNode) {
    if (!graph.graph.containsKey(startNode) || !graph.graph.containsKey(endNode)) {
      return null;
    }
    if (startNode == endNode) {
      LinkedList<Node<Integer>> ll = new LinkedList<>();
      ll.add(graph.graph.get(startNode));
      return ll;
    }
    // Arrays are indexed by node value, so size them by the largest value; graph.graph.size() is
    // too small whenever values are not exactly 0..n-1 (e.g. nodes numbered 1..n).
    int capacity = maxNodeValue(graph) + 1;
    // Java forbids generic array creation; these raw arrays only ever hold node-path lists.
    @SuppressWarnings("unchecked")
    LinkedList<Node<Integer>>[] startTrav = new LinkedList[capacity];
    @SuppressWarnings("unchecked")
    LinkedList<Node<Integer>>[] endTrav = new LinkedList[capacity];
    // BFS queues holding the current frontier of each side.
    LinkedList<Node<Integer>> traversedNodesFromStart = new LinkedList<>();
    LinkedList<Node<Integer>> traversedNodesFromEnd = new LinkedList<>();
    addToDS(graph, traversedNodesFromStart, startTrav, startNode);
    addToDS(graph, traversedNodesFromEnd, endTrav, endNode);
    int collision = -1;
    while (!traversedNodesFromStart.isEmpty() && !traversedNodesFromEnd.isEmpty()) {
      // Expand one full level from the start; paths are stored start -> node.
      recordAllChild(traversedNodesFromStart.size(), traversedNodesFromStart, startTrav, true);
      // The queue now holds exactly the new level; no collision existed before it, so only these
      // nodes need checking.
      if ((collision = checkColission(traversedNodesFromStart, endTrav)) != -1) {
        break;
      }
      // Expand one full level from the end; paths are stored node -> end (reversed).
      recordAllChild(traversedNodesFromEnd.size(), traversedNodesFromEnd, endTrav, false);
      if ((collision = checkColission(traversedNodesFromEnd, startTrav)) != -1) {
        break;
      }
    }
    if (collision == -1) {
      return null;
    }
    // startTrav[c] ends with c and endTrav[c] begins with c: drop the duplicate, then concatenate.
    LinkedList<Node<Integer>> path = startTrav[collision];
    LinkedList<Node<Integer>> tail = endTrav[collision];
    tail.removeFirst();
    path.addAll(tail);
    return path;
  }

  /**
   * Dequeues {@code levelSize} nodes (one BFS level) and enqueues each of their unvisited
   * children, recording each child's path as its parent's path extended by the child, appended
   * at the end when {@code addAtLast} is true and at the front otherwise.
   */
  private void recordAllChild(
      int levelSize,
      LinkedList<Node<Integer>> traversedNodes,
      LinkedList<Node<Integer>>[] travArr,
      boolean addAtLast) {
    for (int remaining = levelSize; remaining > 0; remaining--) {
      Node<Integer> node = traversedNodes.remove();
      for (Node<Integer> child : node.nextNodes) {
        if (travArr[child.value] == null) {
          LinkedList<Node<Integer>> ll = new LinkedList<>(travArr[node.value]);
          if (addAtLast) {
            ll.add(child);
          } else {
            ll.addFirst(child);
          }
          travArr[child.value] = ll;
          // Enqueue exactly once; the child is now marked visited in travArr.
          traversedNodes.add(child);
        }
      }
    }
  }

  /**
   * Returns the value of the first node in {@code frontier} that the other side has already
   * reached (non-null entry in {@code otherTrav}), or -1 if there is none.
   */
  private int checkColission(
      LinkedList<Node<Integer>> frontier, LinkedList<Node<Integer>>[] otherTrav) {
    for (Node<Integer> node : frontier) {
      if (otherTrav[node.value] != null) {
        return node.value;
      }
    }
    return -1;
  }

  /** Seeds one side of the search with {@code node}: a one-element path plus an enqueue. */
  private void addToDS(
      Graph<Integer> graph,
      LinkedList<Node<Integer>> traversedNodes,
      LinkedList<Node<Integer>>[] travArr,
      int node) {
    LinkedList<Node<Integer>> ll = new LinkedList<>();
    ll.add(graph.graph.get(node));
    travArr[node] = ll;
    traversedNodes.add(graph.graph.get(node));
  }

  /** Largest node value appearing as a graph key or as a neighbour of a keyed node. */
  private static int maxNodeValue(Graph<Integer> graph) {
    int max = 0;
    for (Integer key : graph.graph.keySet()) {
      max = Math.max(max, requireNonNegative(key));
      for (Node<Integer> child : graph.graph.get(key).nextNodes) {
        max = Math.max(max, requireNonNegative(child.value));
      }
    }
    return max;
  }

  /** Returns {@code value}, or throws if it cannot be used as an array index. */
  private static int requireNonNegative(int value) {
    if (value < 0) {
      throw new IllegalArgumentException("Node values must be non-negative: " + value);
    }
    return value;
  }
}
Hope it helps.
Happy coding.