
Given a BST with integer values as keys, how do I find the node closest to a given key? The BST is represented using node objects (Java). For example, if the keys are 4, 5 and 9 and the given key is 6, it should return 5.

phoenix

11 Answers


Traverse the tree as you would to find the element. While doing so, record the value that is closest to your key. If you don't find a node with the key itself, return the recorded value.

So if you were looking for the key 3 in the following tree, you would end up on the node 6 without finding a match, but your recorded value would be 2, since it was the closest key of all the nodes you had traversed (2, 7, 6).

                 2
              1      7
                   6   8
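The approach above can be sketched in Java, the question's language, as an iterative walk down the search path. This is a minimal sketch: the Node class and its key/left/right fields are assumptions standing in for the asker's node objects.

```java
// Hypothetical node type standing in for the asker's BST nodes.
class Node {
    int key;
    Node left, right;

    Node(int key) { this.key = key; }

    Node(int key, Node left, Node right) {
        this.key = key;
        this.left = left;
        this.right = right;
    }
}

class ClosestSearch {
    // Follows the normal BST search path and remembers the closest key seen on it.
    // Returns null only for an empty tree.
    static Node closest(Node root, int key) {
        Node best = null;
        for (Node cur = root; cur != null; ) {
            // Compare as long so the subtraction cannot overflow.
            if (best == null || Math.abs((long) cur.key - key) < Math.abs((long) best.key - key)) {
                best = cur;
            }
            if (key == cur.key) {
                return cur; // an exact match is always the closest
            }
            cur = key < cur.key ? cur.left : cur.right;
        }
        return best;
    }

    public static void main(String[] args) {
        // The example tree from this answer:   2
        //                                     1     7
        //                                         6   8
        Node root = new Node(2, new Node(1), new Node(7, new Node(6), new Node(8)));
        System.out.println(closest(root, 3).key); // prints 2
    }
}
```

The loop touches one node per level, so it runs in O(h) time for a tree of height h with O(1) extra space; an exact match returns immediately because nothing can be closer.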
x4u

Here's a recursive solution in Python:

def searchForClosestNodeHelper(root, val, closestNode):
    if root is None:
        return closestNode

    if root.val == val:
        return root

    if closestNode is None or abs(root.val - val) < abs(closestNode.val - val):
        closestNode = root

    if val < root.val:
        return searchForClosestNodeHelper(root.left, val, closestNode)
    else:
        return searchForClosestNodeHelper(root.right, val, closestNode)

def searchForClosestNode(root, val):
    return searchForClosestNodeHelper(root, val, None)
Clay Schubiner

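It can be solved in O(h) time, where h is the height of the tree (O(log n) when the tree is balanced). Starting from the root, remember the node with the smallest distance to the given value seen so far, and:

  • If the value in a node is equal to the given value, it's the closest node;
  • If the value in a node is greater than the given value, move to the left child;
  • If the value in a node is less than the given value, move to the right child.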

The algorithm can be implemented with the following C++ code:

#include <climits>   // INT_MAX
#include <cstdlib>   // abs

BinaryTreeNode* getClosestNode(BinaryTreeNode* pRoot, int value)
{
    BinaryTreeNode* pClosest = NULL;
    int minDistance = INT_MAX;
    BinaryTreeNode* pNode = pRoot;
    while(pNode != NULL){
        int distance = abs(pNode->m_nValue - value);
        if(distance < minDistance){
            minDistance = distance;
            pClosest = pNode;
        }

        if(distance == 0)
            break;

        if(pNode->m_nValue > value)
            pNode = pNode->m_pLeft;
        else if(pNode->m_nValue < value)
            pNode = pNode->m_pRight;
    }

    return pClosest;
}


Harry He
  • Agreed, this iterative solution will run in O(log(n)) time and O(1) space on average, with worst case O(n) time and O(1) space. The recursive solution would be O(log(n)) time and O(log(n)) space, which is why the iterative solution is nice: it uses less space by not adding recursive calls to the call stack. – Tanner Dolby Jan 23 '22 at 23:16

A full traversal takes O(n) time. Can we proceed top-down instead, like this recursive code?

Tnode * closestBST(Tnode * root, int val){
    if(root->val == val)
        return root;
    if(val < root->val){
        if(!root->left)
            return root;
        Tnode * p = closestBST(root->left, val);
        return abs(p->val-val) > abs(root->val-val) ? root : p;
    }else{
        if(!root->right)
            return root;
        Tnode * p = closestBST(root->right, val);
        return abs(p->val-val) > abs(root->val-val) ? root : p;
    }   
    return nullptr;   // unreachable: both branches above already return
}
gopher_rocks
  • The function can return null. If p = null, p->val is invalid. In fact, the last line "return null" is unreachable. – vic Oct 07 '16 at 06:01
  • This appears to actually be O(h), where h is the height of the BST, vs O(n). – thun Jan 17 '17 at 17:14

The problem with the approach "left/right traversal while finding the closest" is that it depends on the sequence in which elements were inserted to create the BST. If we search for 11 in the BST built from the sequence 22, 15, 16, 6, 14, 3, 1, 90, the above method will return 15, while the correct answer is 14. The only reliable method is to use recursion to traverse all the nodes, returning the closest one as the result of the recursive function. This gives the closest value.

Pramod
  • That is simply not true. Teng Teng's answer above: http://stackoverflow.com/a/6276239/906751 will work for the case you described, and does not traverse every node in the tree. – KSletmoe May 17 '16 at 22:28
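To check the example from this answer concretely, here is a minimal Java sketch that builds the BST by inserting 22, 15, 16, 6, 14, 3, 1, 90 in that order and then runs the top-down search that records the closest key on its path. The class and method names are hypothetical, and sending equal keys to the right is an assumption.

```java
// Hypothetical node type for this sketch.
class BstNode {
    int key;
    BstNode left, right;

    BstNode(int key) { this.key = key; }
}

class InsertionOrderDemo {
    // Standard recursive BST insert; equal keys go to the right (an assumption).
    static BstNode insert(BstNode root, int key) {
        if (root == null) return new BstNode(key);
        if (key < root.key) root.left = insert(root.left, key);
        else root.right = insert(root.right, key);
        return root;
    }

    // Top-down search that remembers the closest key seen on the search path.
    // Assumes a non-empty tree.
    static int closestKey(BstNode root, int target) {
        int best = root.key;
        for (BstNode cur = root; cur != null; cur = target < cur.key ? cur.left : cur.right) {
            if (Math.abs((long) cur.key - target) < Math.abs((long) best - target)) best = cur.key;
            if (cur.key == target) break; // exact match, nothing can be closer
        }
        return best;
    }

    public static void main(String[] args) {
        BstNode root = null;
        for (int k : new int[] {22, 15, 16, 6, 14, 3, 1, 90}) root = insert(root, k);
        // The search path for 11 is 22 -> 15 -> 6 -> 14, so 14 is seen and kept.
        System.out.println(closestKey(root, 11)); // prints 14
    }
}
```

The search for 11 walks 22, 15, 6, 14, so 14 lies on the path and is the value returned, without visiting every node.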

This can be done using a Queue and an ArrayList. The Queue is used to perform a breadth-first search on the tree, and the ArrayList stores the elements of the tree in breadth-first order. Here is the code:

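import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;

Queue<Node> queue = new LinkedList<>();
List<Node> list = new ArrayList<>();

// Returns the node that follows the node holding `key` in breadth-first order (null if none).
public Node findNextRightNode(Node root, int key)
{
    System.out.print("The breadth first search on Tree : \t");
    if(root == null)
        return null;

    queue.clear();
    list.clear();   // reset results from any previous call
    queue.add(root);

    while(!queue.isEmpty())
    {
        Node node = queue.remove();
        System.out.print(node.data + " ");
        list.add(node);
        if(node.left != null) queue.add(node.left);
        if(node.right != null) queue.add(node.right);
    }

    Iterator<Node> iter = list.iterator();
    while(iter.hasNext())
    {
        if(iter.next().data == key)
        {
            // Guard with hasNext(): the key may belong to the last node in BFS order.
            return iter.hasNext() ? iter.next() : null;
        }
    }

    return null;
}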
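If you want to use a queue-based breadth-first search for this question, the sketch below keeps the node closest to the key rather than the node that follows it. It visits every node, so it takes O(n) time instead of the O(h) of the top-down searches in other answers, but it does not rely on BST ordering and therefore also works on an arbitrary binary tree. BfsNode is a hypothetical stand-in for the tree's node class.

```java
import java.util.ArrayDeque;
import java.util.Queue;

// Hypothetical node type for this sketch.
class BfsNode {
    int data;
    BfsNode left, right;

    BfsNode(int data) { this.data = data; }

    BfsNode(int data, BfsNode left, BfsNode right) {
        this.data = data;
        this.left = left;
        this.right = right;
    }
}

class BfsClosest {
    // Visits every node breadth-first and keeps the one closest to key.
    // O(n) time, O(width of the tree) extra space for the queue.
    static BfsNode closest(BfsNode root, int key) {
        if (root == null) return null;
        Queue<BfsNode> queue = new ArrayDeque<>(); // ArrayDeque rejects nulls, so only real children are added
        queue.add(root);
        BfsNode best = root;
        while (!queue.isEmpty()) {
            BfsNode node = queue.remove();
            if (Math.abs((long) node.data - key) < Math.abs((long) best.data - key)) {
                best = node;
            }
            if (node.left != null) queue.add(node.left);
            if (node.right != null) queue.add(node.right);
        }
        return best;
    }
}
```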
Shivam Verma
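// Call as closestNode(root, k, null). Java passes references by value, so the
// closest node must be returned; reassigning `result` alone is not visible to the caller.
Node closestNode(Node root, int k, Node result) {
    if (root == null)
    {
       return result;      // fell off the tree: result holds the closest node seen
    }
    if (result == null || Math.abs(root.data - k) < Math.abs(result.data - k))
    {
      result = root;
    }
    if (k < root.data)
    {
        return closestNode(root.left, k, result);
    }
    else
    {
        return closestNode(root.right, k, result);
    }
}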
Victor

The version below works with the different samples I have tried.

public Node findNearest(Node root, int k) {
    if (root == null) {
        return null;
    }
    Node minAt = root;
    int minDiff = Math.abs(k - root.data);

    while (root != null) {
        if (k == root.data) {
            return root;
        }
        int curDiff = Math.abs(k - root.data);
        if (curDiff < minDiff) {
            // Update both the best node and the best difference; comparing against
            // a minDiff that is never updated would keep the last node that beat the root.
            minDiff = curDiff;
            minAt = root;
        }
        root = (k < root.data) ? root.left : root.right;
    }
    return minAt;
}

Here is the full Java code to find the closest element in a BST.

        package binarytree;

        class BSTNode {
            BSTNode left,right;
            int data;

            public BSTNode(int data) {
                this.data = data;
                this.left = this.right = null;
            }
        }

        class BST {
            BSTNode root;

            public static BST createBST() {
                BST bst = new BST();
                bst.root = new BSTNode(9);
                bst.root.left = new BSTNode(4);
                bst.root.right = new BSTNode(17);

                bst.root.left.left = new BSTNode(3);
                bst.root.left.right= new BSTNode(6);

                bst.root.left.right.left= new BSTNode(5);
                bst.root.left.right.right= new BSTNode(7);

                bst.root.right.right = new BSTNode(22);
                bst.root.right.right.left = new BSTNode(20);

                return bst;
            }
        }

        public class ClosestElementInBST {
            public static void main(String[] args) {
                BST bst = BST.createBST();
                int target = 18;
                BSTNode currentClosest = null;
                BSTNode closestNode = findClosestElement(bst.root, target, currentClosest);

                if(closestNode != null) {
                    System.out.println("Found closest node: " + closestNode.data);
                }
                else {
                    System.out.println("Couldn't find closest node.");
                }
            }

            private static BSTNode findClosestElement(BSTNode node, int target, BSTNode currentClosest) {
                if(node == null) return currentClosest;

                // Keep whichever is closer: the best node so far or the current node.
                if(currentClosest == null ||
                        Math.abs(currentClosest.data - target) > Math.abs(node.data - target)) {
                    currentClosest = node;
                }

                if(node.data == target) return node;

                else if(target < node.data) {
                    return findClosestElement(node.left, target, currentClosest);
                }

                else { //target > node.data
                    // Don't overwrite currentClosest here; the comparison above already decided it.
                    return findClosestElement(node.right, target, currentClosest);
                }
            }

        }
Joe

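Here is a working solution in Java that uses the ordering property of the BST and a static field to store the closest value found so far:

public class ClosestValueBinaryTree {
    static int closestValue;

    // Entry point: seed closestValue with a key from the tree, because the
    // default value 0 may not be in the tree and could be wrongly reported.
    public static int findClosest(Node22 root, int target) {
        closestValue = root.data;   // assumes root != null
        closestValueBST(root, target);
        return closestValue;
    }

    public static void closestValueBST(Node22 node, int target) {
        if (node == null) {
            return;
        }
        if (node.data - target == 0) {
            closestValue = node.data;
            return;
        }
        if (Math.abs(node.data - target) < Math.abs(closestValue - target)) {
            closestValue = node.data;
        }
        if (node.data - target < 0) {
            closestValueBST(node.right, target);
        } else {
            closestValueBST(node.left, target);
        }
    }
}

Run time complexity - O(h), where h is the height of the tree: O(logN) for a balanced tree, O(N) in the worst case

Space complexity - O(h) for the recursion stack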

Aarish Ramesh

Given that we're provided a Binary Search Tree, the efficient approach is to traverse the tree, computing each node's absolute difference (distance) from the target and updating the closest value whenever we encounter a node closer to the target. At the same time, we compare the current node's value to the target: if it's less than the target, we search the right subtree, whose values are greater than or equal to the current node; if it's greater than the target, we search the left subtree, whose values are strictly less than the current node.

Doing this, we eliminate roughly half of the remaining BST at each step when the tree is balanced: we traverse either the left subtree (discarding the right half) or the right subtree (discarding the left half), while keeping track of the closest node and updating it whenever we find a node closer to the target.

For the BST you've provided (4, 5 and 9, with 5 at the root), the requirements of a BST are indeed satisfied:

  • all values left of the root node are strictly less than the root node
  • all values right of the root node are greater than or equal to the root node
  • each parent node can only have a maximum of two children
  5
 / \
4   9

For context, a node in the BST has the following structure:

struct Node {
  int data;
  Node *left;
  Node *right;
  Node() { data = 0; left = right = nullptr; }
  Node(int val) { data = val; left = right = nullptr; }
};
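The three properties listed above can be checked mechanically. The following Java sketch (class names hypothetical) passes each subtree the half-open range [low, high) its keys must fall in, matching this answer's convention of strictly smaller keys on the left and greater-or-equal keys on the right; the at-most-two-children rule is already enforced by the node type.

```java
// Hypothetical node type for this sketch.
class ValNode {
    int data;
    ValNode left, right;

    ValNode(int data) { this.data = data; }

    ValNode(int data, ValNode left, ValNode right) {
        this.data = data;
        this.left = left;
        this.right = right;
    }
}

class BstCheck {
    // True if every key in the subtree lies in [low, high) and the same holds
    // recursively for both children. long bounds leave room around int keys.
    static boolean isBst(ValNode node, long low, long high) {
        if (node == null) {
            return true;
        }
        if (node.data < low || node.data >= high) {
            return false;
        }
        return isBst(node.left, low, node.data)      // left: strictly less than node.data
            && isBst(node.right, node.data, high);   // right: greater than or equal to node.data
    }

    static boolean isBst(ValNode root) {
        return isBst(root, Long.MIN_VALUE, Long.MAX_VALUE);
    }
}
```

Checking against bounds inherited from all ancestors, rather than only against the direct parent, is what catches a key such as 6 placed in the right subtree of 3 but inside the left subtree of 5.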

Below are a few C++ solutions, but the logic carries over to Java syntax quite easily.

The recursive approach runs in O(log(n)) time and O(log(n)) space on average. The extra space comes from the recursive minDiffHelper calls: each one adds a frame to the call stack.

// On average: O(log(n)) time and O(log(n)) space
// Worst case: O(n) time and O(n) space
// where n = number of nodes in the tree

int minDiffHelper(Node *root, int target, int closest);

int minDiff(Node *root, int target) {
    return minDiffHelper(root, target, root->data);
}

int minDiffHelper(Node *root, int target, int closest) {
    if (abs(target-closest) > abs(target-root->data)) {
        closest = root->data;
    }
    if (root->left != nullptr && root->data > target) {
        return minDiffHelper(root->left, target, closest);
    } else if (root->right != nullptr && root->data < target) {
        return minDiffHelper(root->right, target, closest);
    } else {
        return closest;
    }
}

The iterative approach also runs in O(log(n)) time on average, but it doesn't add any recursive calls to the call stack, so it consumes only constant O(1) space rather than the O(log(n)) space we see when recursing.

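In the worst case (a degenerate, list-like tree) both algorithms take O(n) time, which can also be written as O(d) where d = depth of the tree. The recursive version also needs O(d) stack space in that case, while the iterative version stays at O(1) space.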

// On average: O(log(n)) time and O(1) space
// Worst case: O(n) time and O(1) space
// where n = number of nodes in the tree

int minDiffHelper(Node *root, int target, int closest);

int minDiff(Node *root, int target) {
    return minDiffHelper(root, target, root->data);
}

int minDiffHelper(Node *root, int target, int closest) {
    Node *current = root;
    while (current != nullptr) {
        if (abs(target-closest) > abs(target-current->data)) {
            closest = current->data;
        }
        if (current->left != nullptr && current->data > target) {
            current = current->left;
        } else if (current->right != nullptr && current->data < target) {
            current = current->right;
        } else break;
    }
    return closest;
}
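To see where the O(n) worst case comes from, the Java sketch below (all names hypothetical) builds two trees over the keys 1 to 1000: one by inserting them in sorted order, which degenerates into a chain of right children, and one balanced tree built by repeatedly taking the middle key as the root. It then counts how many nodes the top-down search path visits when looking for 1000; this is the same path the closest-value search walks.

```java
// Hypothetical node type for this sketch.
class DepthNode {
    int key;
    DepthNode left, right;

    DepthNode(int key) { this.key = key; }
}

class DepthDemo {
    // Iterative insert (avoids deep recursion on the degenerate tree); equal keys go right.
    static DepthNode insert(DepthNode root, int key) {
        DepthNode fresh = new DepthNode(key);
        if (root == null) return fresh;
        DepthNode cur = root;
        while (true) {
            if (key < cur.key) {
                if (cur.left == null) { cur.left = fresh; return root; }
                cur = cur.left;
            } else {
                if (cur.right == null) { cur.right = fresh; return root; }
                cur = cur.right;
            }
        }
    }

    // Builds a height-balanced tree from sorted keys[lo..hi], using the middle key as the root.
    static DepthNode balanced(int[] keys, int lo, int hi) {
        if (lo > hi) return null;
        int mid = (lo + hi) >>> 1;
        DepthNode n = new DepthNode(keys[mid]);
        n.left = balanced(keys, lo, mid - 1);
        n.right = balanced(keys, mid + 1, hi);
        return n;
    }

    // Counts the nodes on the search path for target, stopping at an exact match.
    static int visits(DepthNode root, int target) {
        int count = 0;
        for (DepthNode cur = root; cur != null; cur = target < cur.key ? cur.left : cur.right) {
            count++;
            if (cur.key == target) break;
        }
        return count;
    }

    public static void main(String[] args) {
        int n = 1000;
        int[] keys = new int[n];
        DepthNode chain = null;
        for (int i = 0; i < n; i++) {
            keys[i] = i + 1;
            chain = insert(chain, i + 1); // sorted insertion: every node becomes a right child
        }
        DepthNode bal = balanced(keys, 0, n - 1);
        System.out.println(visits(chain, n) + " vs " + visits(bal, n)); // prints "1000 vs 10"
    }
}
```

The chain forces the search through every node, while the balanced tree needs only about log2(1000) steps; this is why the average-case figures above assume a reasonably balanced tree.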

GeeksForGeeks has a nice practice problem for testing your understanding. It asks for a slightly different result (the distance from the target to the closest element), which is easily handled by returning abs(target-closest) instead of the node value closest.
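A minimal Java sketch of that distance-returning variant, using the same top-down walk as the iterative version above. GfgNode and minDiff are hypothetical names here, and the sketch assumes a non-empty tree with keys small enough that target - key cannot overflow an int.

```java
// Hypothetical node type for this sketch.
class GfgNode {
    int data;
    GfgNode left, right;

    GfgNode(int data) { this.data = data; }

    GfgNode(int data, GfgNode left, GfgNode right) {
        this.data = data;
        this.left = left;
        this.right = right;
    }
}

class MinDiff {
    // Returns |target - closest key| instead of the key itself.
    // Stops early once the distance reaches 0 (exact match).
    static int minDiff(GfgNode root, int target) {
        int best = Math.abs(target - root.data);
        GfgNode cur = root;
        while (cur != null && best > 0) {
            best = Math.min(best, Math.abs(target - cur.data));
            cur = target < cur.data ? cur.left : cur.right;
        }
        return best;
    }
}
```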

Tanner Dolby