Let's assume that the node structure looks like this (Java).
```java
class Node {
    Node left;
    Node right;
    int key;
    int value;
    int tree_max;  // maximum value over this node's entire subtree
}
```
The recurrence for `tree_max` is

`node.tree_max == max(node.value, node.left.tree_max, node.right.tree_max)`,

where, by abuse of notation, we omit `node.left.tree_max` when `node.left` is null and omit `node.right.tree_max` when `node.right` is null. Every time we write to a node (insert it or change its value), the `tree_max` of every ancestor on its path to the root may change, so we may have to recompute all of them, bottom-up. I'm not going to write the pseudocode, because without a compiler I'll most likely get it wrong.
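A minimal sketch of that bookkeeping, assuming a plain unbalanced BST with no parent pointers (the answer doesn't say which kind of tree is used). The class `AugmentedBst`, the helpers `pull` and `put`, and the `Node` constructor are hypothetical additions. Because every ancestor of the written node lies on the search path, a recursive insert can recompute `tree_max` for each of them as the recursion unwinds.

```java
// Node layout from the answer, plus a convenience constructor (an addition).
class Node {
    Node left;
    Node right;
    int key;
    int value;
    int tree_max;

    Node(int key, int value) {
        this.key = key;
        this.value = value;
        this.tree_max = value;  // a leaf's subtree is just itself
    }
}

class AugmentedBst {
    // Hypothetical helper: recompute node.tree_max from node.value and the
    // cached maxima of its non-null children (the recurrence above).
    static void pull(Node node) {
        int m = node.value;
        if (node.left != null) { m = Math.max(m, node.left.tree_max); }
        if (node.right != null) { m = Math.max(m, node.right.tree_max); }
        node.tree_max = m;
    }

    // Hypothetical helper: insert key -> value, or overwrite the value if the
    // key already exists. Unbalanced BST; every node on the search path is an
    // ancestor of the written node and is re-pulled on the way back up.
    static Node put(Node node, int key, int value) {
        if (node == null) { return new Node(key, value); }
        if (key < node.key) {
            node.left = put(node.left, key, value);
        } else if (node.key < key) {
            node.right = put(node.right, key, value);
        } else {
            node.value = value;
        }
        pull(node);
        return node;
    }
}
```

`pull` recomputes from the children rather than taking `max(tree_max, value)`, so overwriting a value with a smaller one also lowers the cached maxima. A self-balancing tree would additionally have to re-`pull` the nodes moved by each rotation, lower node first.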
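To find the max over keys between `k1` and `k2` inclusive, we first locate the least common ancestor of the nodes in that range, i.e. the highest node whose key lies in [`k1`, `k2`]. Note that `k1` and `k2` themselves need not be keys in the tree.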
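```java
// Walk down from the root until we reach the first node whose key lies in [k1, k2].
Node lca = root;
while (lca != null) {
    if (lca.key < k1) { lca = lca.right; }       // whole range is to the right
    else if (k2 < lca.key) { lca = lca.left; }   // whole range is to the left
    else { break; }                              // k1 <= lca.key <= k2
}
```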
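Now, if `lca` is null, then the range is empty, and we should return minus infinity or throw an exception. Otherwise, we need to find the max over three ranges: `k1` inclusive to `lca.key` exclusive, `lca` itself, and `lca.key` exclusive to `k2` inclusive. I'll give the code for `k1` inclusive to `lca.key` exclusive; the other two ranges are trivial (just `lca.value`) and symmetric, respectively. We move `finger` down the tree from `lca.left` as though we're searching for `k1`, accumulating the maximum into `left_max`. Every key in `lca.left`'s subtree is already less than `lca.key`, so whenever `k1 <= finger.key`, both `finger` and its entire right subtree lie in the range, and `finger.right.tree_max` covers that whole subtree in one step.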
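```java
int left_max = Integer.MIN_VALUE;  // stands in for minus infinity
Node finger = lca.left;
while (finger != null) {
    if (k1 <= finger.key) {
        // finger and its whole right subtree lie in [k1, lca.key)
        left_max = Math.max(left_max, finger.value);
        if (finger.right != null) { left_max = Math.max(left_max, finger.right.tree_max); }
        finger = finger.left;
    } else {
        // finger.key < k1: finger and its left subtree are out of range
        finger = finger.right;
    }
}
```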
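A self-contained sketch that assembles the whole query, including the two ranges left to the reader: `lca` itself, and the mirror-image walk down `lca.right` toward `k2`. The class `RangeMaxTree`, the method `rangeMax`, and the `put`/`pull` helpers (a plain unbalanced BST insert that recomputes `tree_max` on every ancestor) are hypothetical names, and an empty range returns `Integer.MIN_VALUE` as the "minus infinity" option.

```java
// Node layout from the answer, plus a convenience constructor (an addition).
class Node {
    Node left;
    Node right;
    int key;
    int value;
    int tree_max;

    Node(int key, int value) {
        this.key = key;
        this.value = value;
        this.tree_max = value;
    }
}

class RangeMaxTree {
    // Hypothetical helper: recompute tree_max from value and non-null children.
    static void pull(Node node) {
        int m = node.value;
        if (node.left != null) { m = Math.max(m, node.left.tree_max); }
        if (node.right != null) { m = Math.max(m, node.right.tree_max); }
        node.tree_max = m;
    }

    // Hypothetical helper: unbalanced BST insert/overwrite that re-pulls every ancestor.
    static Node put(Node node, int key, int value) {
        if (node == null) { return new Node(key, value); }
        if (key < node.key) { node.left = put(node.left, key, value); }
        else if (node.key < key) { node.right = put(node.right, key, value); }
        else { node.value = value; }
        pull(node);
        return node;
    }

    // Max value over all keys in [k1, k2]; Integer.MIN_VALUE if the range is empty.
    static int rangeMax(Node root, int k1, int k2) {
        // 1. Highest node whose key lies in [k1, k2].
        Node lca = root;
        while (lca != null) {
            if (lca.key < k1) { lca = lca.right; }
            else if (k2 < lca.key) { lca = lca.left; }
            else { break; }
        }
        if (lca == null) { return Integer.MIN_VALUE; }

        // 2. lca itself.
        int best = lca.value;

        // 3. k1 inclusive to lca.key exclusive: search for k1 in the left subtree.
        Node finger = lca.left;
        while (finger != null) {
            if (k1 <= finger.key) {
                best = Math.max(best, finger.value);
                if (finger.right != null) { best = Math.max(best, finger.right.tree_max); }
                finger = finger.left;
            } else { finger = finger.right; }
        }

        // 4. lca.key exclusive to k2 inclusive: mirror image, searching for k2.
        finger = lca.right;
        while (finger != null) {
            if (finger.key <= k2) {
                best = Math.max(best, finger.value);
                if (finger.left != null) { best = Math.max(best, finger.left.tree_max); }
                finger = finger.right;
            } else { finger = finger.left; }
        }
        return best;
    }
}
```

For example, after inserting keys 50, 20, 80, 10, 30, 70, 90, 25, 35, 60 with values 3, 9, 1, 4, 7, 2, 8, 6, 5, 0, `rangeMax(root, 21, 69)` covers keys 25, 30, 35, 50 and 60 and returns 7. Each of the three loops descends a single root-to-leaf path, so a query costs O(height) rather than O(number of keys in the range).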