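/*
 * A custom AVL tree implemented as a class, with some useful extensions:
 * - Count the elements greater (or lower) than a value.
 * - Sum the elements greater (or lower) than a value.
 * - Duplicate keys (each node stores how many times its key occurs).
 * - Print the structure of the tree to the console.
 * Insertion, removal and the count/sum queries are designed to run in O(log N);
 * printing draws every slot of the tree's grid, so its cost grows as 2^height.
 * Based on https://www.geeksforgeeks.org/avl-with-duplicate-keys/
 */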
#include <bits/stdc++.h>
using namespace std;
// Augmented AVL tree over int keys that supports duplicate keys.
//
// Each node holds one distinct key, its multiplicity (`count`), and the number
// (`desc`) and sum (`descSum`) of the elements stored in its two child subtrees.
// Those aggregates let the count/sum queries answer in O(height) = O(log N).
//
// Ownership: an AVL owns every node reachable from `node`; copies are deep and
// the destructor frees all nodes.
// NOTE: counts and sums are plain `int`, so very large trees or values can
// overflow them.
class AVL {
public:
    class AVLNode {
    public:
        int value;                 // key stored in this node
        int height = 1;            // height of the subtree rooted here (leaf = 1)
        int count = 1;             // multiplicity of `value`
        int desc = 0;              // number of elements in both child subtrees
        int descSum = 0;           // sum of the elements in both child subtrees
        AVLNode* right = nullptr;
        AVLNode* left = nullptr;

        AVLNode(int value) : value(value) {}

        // Number of elements in this subtree, duplicates included.
        int size() const { return desc + count; }

        // Sum of the elements in this subtree, duplicates included.
        int sum() const { return descSum + value * count; }

        // In-place left rotation: this object keeps its address (so the parent's
        // pointer stays valid) and takes over the right child's contents, while
        // its former contents move into a newly allocated left child.
        // No-op without a right child.
        void rotateLeft() {
            if (right == nullptr) return;
            AVLNode* pivot = right;
            AVLNode* demoted = new AVLNode(*this);
            demoted->right = pivot->left;
            *this = *pivot;
            left = demoted;
            delete pivot;  // its fields were copied into *this
            demoted->update();
            update();
        }

        // Mirror image of rotateLeft(). No-op without a left child.
        void rotateRight() {
            if (left == nullptr) return;
            AVLNode* pivot = left;
            AVLNode* demoted = new AVLNode(*this);
            demoted->left = pivot->right;
            *this = *pivot;
            right = demoted;
            delete pivot;  // its fields were copied into *this
            demoted->update();
            update();
        }

        // Height of the left subtree minus height of the right subtree.
        int getBalance() const { return heightOf(left) - heightOf(right); }

        // Adds one copy of `value` to this subtree and rebalances on the way up.
        // Returns true if a new node was created, false if an existing node's
        // count was incremented (the shape of the tree did not change).
        bool insert(int value) {
            if (value == this->value) {
                count++;
                return false;
            }
            AVLNode*& child = value > this->value ? right : left;
            bool newNodeAdded = true;
            if (child != nullptr)
                newNodeAdded = child->insert(value);
            else
                child = new AVLNode(value);
            rebalance();  // refreshes height/desc/descSum; rotates only if now unbalanced
            return newNodeAdded;
        }

        // Removes one copy of `value` from this subtree and rebalances on the way up.
        // Returns {found, unlinked}: `found` is true when a copy was removed and
        // `unlinked` is true when a node left the tree (so heights may have changed).
        //
        // Caller contract (unchanged): when this node is a leaf holding the last
        // copy of `value` (value matches, count == 1, desc == 0) it deletes itself,
        // and the caller must then clear its pointer to it.
        std::pair<bool, bool> remove(int value) {
            if (value != this->value) {
                AVLNode*& child = value > this->value ? right : left;
                if (child == nullptr) return {false, false};  // not present
                bool childDeletesItself =
                    child->value == value && child->count == 1 && child->desc == 0;
                std::pair<bool, bool> result = child->remove(value);
                if (childDeletesItself) child = nullptr;
                if (!result.first) return result;
                rebalance();
                return result;
            }
            if (--count > 0) return {true, false};  // other copies remain
            if (left == nullptr && right == nullptr) {
                delete this;  // see the caller contract above
                return {true, true};
            }
            if (left == nullptr || right == nullptr) {
                // Replace this node by its only child, which is a balanced subtree.
                AVLNode* child = left != nullptr ? left : right;
                *this = *child;
                delete child;
                return {true, true};
            }
            // Two children: move the in-order successor (the minimum of the right
            // subtree, together with all its copies) into this node, unlinking it
            // from the right subtree with rebalancing along that path.
            int successorValue = 0;
            int successorCount = 0;
            extractMin(right, successorValue, successorCount);
            this->value = successorValue;
            count = successorCount;
            rebalance();
            return {true, true};
        }

        // Multiplicity of `value` in this subtree (0 if absent).
        int countEqual(int value) const {
            const AVLNode* cur = this;
            while (cur != nullptr) {
                if (value == cur->value) return cur->count;
                cur = value > cur->value ? cur->right : cur->left;
            }
            return 0;
        }

        // Number of elements in this subtree strictly greater than `value`.
        int countGreater(int value) const {
            int total = 0;
            const AVLNode* cur = this;
            while (cur != nullptr) {
                if (value == cur->value) return total + sizeOf(cur->right);
                if (value > cur->value) {
                    cur = cur->right;
                } else {
                    // This node and its whole right subtree are greater than `value`.
                    total += sizeOf(cur->right) + cur->count;
                    cur = cur->left;
                }
            }
            return total;
        }

        // Sum of the elements in this subtree strictly greater than `value`.
        int countGreaterSum(int value) const {
            int total = 0;
            const AVLNode* cur = this;
            while (cur != nullptr) {
                if (value == cur->value) return total + sumOf(cur->right);
                if (value > cur->value) {
                    cur = cur->right;
                } else {
                    total += sumOf(cur->right) + cur->value * cur->count;
                    cur = cur->left;
                }
            }
            return total;
        }

    private:
        static int heightOf(const AVLNode* n) { return n != nullptr ? n->height : 0; }
        static int sizeOf(const AVLNode* n) { return n != nullptr ? n->size() : 0; }
        static int sumOf(const AVLNode* n) { return n != nullptr ? n->sum() : 0; }

        // Recomputes height, desc and descSum from the children, whose own
        // fields must already be up to date.
        void update() {
            height = 1 + std::max(heightOf(left), heightOf(right));
            desc = sizeOf(left) + sizeOf(right);
            descSum = sumOf(left) + sumOf(right);
        }

        // Refreshes this node's aggregates, then applies the single or double
        // rotation that restores the AVL invariant if |balance| has reached 2.
        void rebalance() {
            update();
            int balance = getBalance();
            if (balance > 1) {
                if (left->getBalance() < 0) left->rotateLeft();    // Left-Right case
                rotateRight();                                      // Left-Left case
            } else if (balance < -1) {
                if (right->getBalance() > 0) right->rotateRight(); // Right-Left case
                rotateLeft();                                       // Right-Right case
            }
        }

        // Unlinks the minimum node of the non-empty subtree whose root pointer is
        // `slot`, reports its key and multiplicity, frees it, and rebalances every
        // node on the path back up.
        static void extractMin(AVLNode*& slot, int& minValue, int& minCount) {
            AVLNode* cur = slot;
            if (cur->left == nullptr) {
                minValue = cur->value;
                minCount = cur->count;
                slot = cur->right;
                delete cur;
                return;
            }
            extractMin(cur->left, minValue, minCount);
            cur->rebalance();
        }
    };

    AVLNode* node = nullptr;  // root; nullptr when the tree is empty

    AVL() = default;
    AVL(const AVL& other) : node(cloneSubtree(other.node)) {}
    AVL(AVL&& other) noexcept : node(other.node) { other.node = nullptr; }
    // Copy-and-swap: `other` is already a copy (or the moved-in source).
    AVL& operator=(AVL other) noexcept {
        std::swap(node, other.node);
        return *this;
    }
    ~AVL() { destroySubtree(node); }

    int height() const { return node != nullptr ? node->height : 0; }
    // Number of elements, duplicates included.
    int size() const { return node != nullptr ? node->size() : 0; }
    // Sum of all elements of the tree
    int sum() const { return node != nullptr ? node->sum() : 0; }
    int getHeight() const { return height(); }

    // Adds one copy of `value`.
    void insert(int value) {
        if (node == nullptr)
            node = new AVLNode(value);
        else
            node->insert(value);
    }

    // Removes one copy of `value`; does nothing if it is absent.
    void remove(int value) {
        if (node == nullptr) return;
        // The root deletes itself when it is a leaf holding the last copy.
        bool rootDeletesItself = node->value == value && node->count == 1 && node->desc == 0;
        node->remove(value);
        if (rootDeletesItself) node = nullptr;
    }

    // Count of elements with value $value
    int countEqual(int value) const {
        return node != nullptr ? node->countEqual(value) : 0;
    }
    // Count elements greater than value
    int countGreater(int value) const {
        return node != nullptr ? node->countGreater(value) : 0;
    }
    // Count elements greater or equal than value
    // (written without value-1 so INT_MIN cannot overflow)
    int countGreaterEqual(int value) const {
        return countGreater(value) + countEqual(value);
    }
    // Count elements lower than value
    int countLower(int value) const {
        return size() - countGreaterEqual(value);
    }
    // Count elements lower or equal than value
    // (written without value+1 so INT_MAX cannot overflow)
    int countLowerEqual(int value) const {
        return size() - countGreater(value);
    }
    // Count the sum of all elements greater than value
    int countGreaterSum(int value) const {
        return node != nullptr ? node->countGreaterSum(value) : 0;
    }
    // Count the sum of all elements greater or equal than value
    int countGreaterEqualSum(int value) const {
        return countGreaterSum(value) + value * countEqual(value);
    }
    // Count the sum of all elements lower than value
    int countLowerSum(int value) const {
        return sum() - countGreaterEqualSum(value);
    }
    // Count the sum of all elements lower or equal than value
    int countLowerEqualSum(int value) const {
        return sum() - countGreaterSum(value);
    }

private:
    // Deep-copies the subtree rooted at `src` (nullptr for an empty subtree).
    static AVLNode* cloneSubtree(const AVLNode* src) {
        if (src == nullptr) return nullptr;
        AVLNode* copy = new AVLNode(*src);
        copy->left = cloneSubtree(src->left);
        copy->right = cloneSubtree(src->right);
        return copy;
    }

    // Frees every node of the subtree rooted at `root`.
    static void destroySubtree(AVLNode* root) {
        if (root == nullptr) return;
        destroySubtree(root->left);
        destroySubtree(root->right);
        delete root;
    }
};
void printAVLTree(AVL avl) {
queue<AVL::AVLNode*> queue;
queue.push(avl.node);
AVL::AVLNode* node = avl.node;
while (true)
if (node->right != NULL)
node = node->right;
else
break;
int numDigits = to_string(node->value).size();
int height = max(avl.getHeight(), 1); int currentHeight = height;
vector<int> spaces = {0, max(numDigits, 1)};
for (int i = 2; i < height+2; i++)
spaces.push_back(2*spaces[i-1]+spaces[1]);
string connectionsString = "";
for(int i = 1; ; i++) {
AVL::AVLNode* node = queue.front();
queue.pop();
if ((int)log2(i)-log2(i) == 0) {
cout << "\n";
if (currentHeight == 0)
break;
for (int j = 0; j < spaces[currentHeight-1]; j++) {
if (j < spaces[currentHeight-1]*4/5)
connectionsString += " ";
cout << " ";
}
currentHeight--;
}
string value = node != NULL ? to_string(node->value) : ".";
for (int j = 0; j < -(int)value.size()+numDigits; j++)
cout << " ";
cout << value;
for (int j = 0; j < spaces[currentHeight+1]; j++)
cout << " ";
if (node != NULL) {
queue.push(node->left);
queue.push(node->right);
}else {
queue.push(NULL);
queue.push(NULL);
}
}
cout << "\n";
}
// Dumps the subtree rooted at `node` in pre-order (node, then left, then right)
// as "{v: <value>, c: <count>} " entries; every empty child slot prints "{.} ".
void printAVLNodes(AVL::AVLNode* node) {
    if (node == NULL) {
        std::cout << "{.} ";
        return;
    }
    std::cout << "{v: " << node->value
              << ", c: " << node->count
              << "} ";
    printAVLNodes(node->left);
    printAVLNodes(node->right);
}
// Demo: builds a small tree with duplicates, removes a few elements, prints some
// query results, then prints the tree as a grid and as a pre-order node list.
int main() {
    AVL avl;

    // 2 and 1 are inserted more than once to exercise duplicate keys.
    const int toInsert[] = {1, 2, 3, 5, 4, 0, 2, 1, 1};
    for (int v : toInsert)
        avl.insert(v);

    // Removes one copy of 1 and the only 5.
    const int toRemove[] = {1, 5};
    for (int v : toRemove)
        avl.remove(v);

    std::cout << "There are " << avl.countEqual(1) << " elements with value 1" << "\n";
    std::cout << "There are " << avl.countGreater(1) << " elements greater than 1" << "\n";
    std::cout << "The sum of all elements greater than 1 is " << avl.countGreaterSum(1) << "\n";

    printAVLTree(avl);
    printAVLNodes(avl.node);
    return 0;
}