-1

I'm trying to implement an AVL tree as a class, but my insertion function doesn't work as expected. Every time I insert an element, execution only reaches the `if (x == nullptr) return new node(key);` line and stops there. So a node gets created somewhere, but it is never actually added to the tree. Is something wrong with my insertion function, or with the way I implement the class itself?

#include <iostream>

using namespace std;

template<class T>
class AVL {
private:
    // A tree node; `height` is the height of the subtree rooted here
    // (a freshly created leaf has height 1).
    struct node {
        T key;
        node *left;
        node *right;
        int height{};
        explicit node(T key) {
            this->key = key;
            this->left = nullptr;
            this->right = nullptr;
            this->height = 1;
        }
    };

    // Height of a possibly-empty subtree; an empty subtree has height 0.
    static int heightOf(const node *x) {
        return x != nullptr ? x->height : 0;
    }

    // Recomputes x->height from its children, whose heights must already be correct.
    void updateHeight(node *x) {
        x->height = 1 + max(heightOf(x->left), heightOf(x->right));
    }

public:
    // Root of the tree; nullptr while the tree is empty.
    node *root;

    AVL() {
        this->root = nullptr;
    }

    int max (int x, int y) {
        return (x > y)? x : y;
    }

    // Height of the left subtree minus height of the right subtree
    // (positive = left-heavy). Returns 0 for an empty subtree.
    int balanceFactor (node *x){
        if (x == nullptr) return 0;
        return heightOf(x->left) - heightOf(x->right);
    }

    // Rotates the subtree rooted at x to the left and returns the new subtree
    // root. Requires x->right != nullptr.
    node *leftRotate (node *x) {
        node *y = x->right;

        // Re-home y's left subtree *before* x becomes y's left child.
        x->right = y->left;
        y->left = x;

        // x now sits below y, so its height must be recomputed first.
        updateHeight(x);
        updateHeight(y);

        return y;
    }

    // Rotates the subtree rooted at y to the right and returns the new subtree
    // root. Requires y->left != nullptr.
    node *rightRotate (node *y) {
        node *x = y->left;

        y->left = x->right;
        x->right = y;

        updateHeight(y);
        updateHeight(x);

        return x;
    }

    // Refreshes x's height and, if x is out of balance by more than one,
    // performs the single or double rotation that restores the AVL property.
    // Returns the (possibly new) root of the subtree.
    node *balance (node *x) {
        updateHeight(x);
        const int bf = balanceFactor(x);

        if (bf > 1) {                         // left-heavy
            if (balanceFactor(x->left) < 0)   // Left-Right case
                x->left = leftRotate(x->left);
            return rightRotate(x);            // Left-Left case
        }

        if (bf < -1) {                        // right-heavy
            if (balanceFactor(x->right) > 0)  // Right-Left case
                x->right = rightRotate(x->right);
            return leftRotate(x);             // Right-Right case
        }

        return x;                             // already balanced
    }

    // Inserts key into the subtree rooted at x and returns the new subtree root.
    // The caller MUST store the result (e.g. `root = insert(root, key);`);
    // otherwise a node created for an empty subtree is never linked in.
    // Duplicate keys are ignored.
    node *insert (node *x, const T &key){
        if (x == nullptr) return new node(key);
        if (key < x->key)
            x->left = insert(x->left, key);
        else if (key > x->key)
            x->right = insert(x->right, key);
        else return x;
        return balance(x);
    }

    // Convenience overload: inserts key into the whole tree and updates root.
    void insert(const T &key) {
        root = insert(root, key);
    }

    // Prints the keys of the subtree in sorted order, space-separated.
    void inorder(node *leaf) {
        if (leaf != nullptr) {
            inorder(leaf->left);
            std::cout << leaf->key << " ";
            inorder(leaf->right);
        }
    }

    // Prints the keys of the subtree in node-left-right order.
    void preorder(node *leaf) {
        if (leaf != nullptr) {
            std::cout << leaf->key << " ";
            preorder(leaf->left);
            preorder(leaf->right);
        }
    }

    // Prints the keys of the subtree in left-right-node order.
    void postorder(node *leaf) {
        if (leaf != nullptr) {
            postorder(leaf->left);
            postorder(leaf->right);
            std::cout << leaf->key << " ";
        }
    }

};

int main() {
    system("chcp 65001"); //UTF-8 for Windows console

    int size;
    cin >> size;
    AVL<int> avl;
    for (int i = 0; i < size; i++) {
        int x;
        cin >> x;
        avl.insert(avl.root, x);
    }

    avl.inorder(avl.root);
    avl.preorder(avl.root);
    avl.postorder(avl.root);

    return 0;
}
primadonna
  • 142
  • 4
  • 12
  • 2
    `insert` returns the new node, but since you ignore that result in `main`, you can do that forever with no change to the world. – 500 - Internal Server Error Apr 01 '21 at 17:41
  • Off-topic (maybe), but why does the caller to `insert` need to know anything about where the node was inserted? In other words, why is the return value of `insert` so important for your code to work? The user of `avl` inserts an item, and `avl` figures out how to insert the item and balance the tree (if necessary) -- the caller shouldn't need to save the return value for anything. Maybe that's why you thought your code should work, and it should work, **if** you didn't put the responsibility of inserting a node on the calling code having to use that return value. – PaulMcKenzie Apr 01 '21 at 18:34

1 Answer

0

You can use this custom AVL tree as a class implementation with some useful improvements.

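  • Count the nodes greater than a given value.
  • Sum the values of the nodes greater than a given value.
  • Support for duplicate keys.
  • Print the structure of the tree to the console.

Insertions, removals and the counting queries all run in O(log N) time (printing, naturally, has to visit the whole tree).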

Based on https://www.geeksforgeeks.org/avl-with-duplicate-keys/

#include <bits/stdc++.h>

using namespace std;

class AVL {
    public:
    // One node of a multiset AVL tree. Equal values share a node and are
    // tracked by `count`. `desc`/`descSum` cache how many elements (with
    // multiplicity) the two subtrees hold and their sum, which is what keeps
    // the counting queries O(log N).
    //
    // Rotations work in place: the object at `this` keeps its address and
    // takes over the contents of the child that moves up, so a parent's
    // pointer to this node stays valid across rebalancing.
    class AVLNode {
        public:
        int value;
        int height = 1;   // height of the subtree rooted here (a leaf is 1)
        int count = 1;    // number of copies of `value`
        int desc = 0;     // elements stored in the left and right subtrees
        int descSum = 0;  // sum of the elements stored in both subtrees

        AVLNode* right = nullptr;
        AVLNode* left = nullptr;
        AVLNode (int value) {
            this->value = value;
        }

        // Height of a possibly-null subtree (0 when empty).
        static int heightOf(const AVLNode* n) { return n != nullptr ? n->height : 0; }
        // Number of elements, with multiplicity, in a possibly-null subtree.
        static int sizeOf(const AVLNode* n) { return n != nullptr ? n->desc + n->count : 0; }
        // Sum of the elements in a possibly-null subtree.
        static int sumOf(const AVLNode* n) { return n != nullptr ? n->descSum + n->value * n->count : 0; }

        // Elements stored in this subtree, counting duplicates.
        int size() { return desc+count; }
        // Sum of the elements stored in this subtree, counting duplicates.
        int sum() { return descSum+value*count; }

        // Recomputes height, desc and descSum from the children, whose own
        // metadata must already be up to date.
        void refresh() {
            height = 1 + std::max(heightOf(left), heightOf(right));
            desc = sizeOf(left) + sizeOf(right);
            descSum = sumOf(left) + sumOf(right);
        }

        // Left rotation performed in place (no-op without a right child).
        void rotateLeft() {
            if (right == nullptr) return;

            AVLNode* oldRight = right;
            AVLNode* demoted = new AVLNode(*this);   // current contents move down-left
            demoted->right = oldRight->left;
            demoted->refresh();

            *this = *oldRight;                        // right child's contents move up
            left = demoted;
            refresh();

            delete oldRight;                          // its contents now live in *this
        }

        // Right rotation performed in place (no-op without a left child).
        void rotateRight() {
            if (left == nullptr) return;

            AVLNode* oldLeft = left;
            AVLNode* demoted = new AVLNode(*this);   // current contents move down-right
            demoted->left = oldLeft->right;
            demoted->refresh();

            *this = *oldLeft;                         // left child's contents move up
            right = demoted;
            refresh();

            delete oldLeft;                           // its contents now live in *this
        }

        // Left subtree height minus right subtree height.
        int getBalance() {
            return heightOf(left) - heightOf(right);
        }

        // Refreshes this node's metadata and, if it is out of balance, applies
        // the single or double rotation that restores the AVL property.
        // Assumes both children are valid AVL subtrees.
        void rebalance() {
            refresh();
            const int balance = getBalance();
            if (balance > 1) {                   // left-heavy
                if (left->getBalance() < 0)      // Left Right Case
                    left->rotateLeft();
                rotateRight();                   // Left Left Case
            } else if (balance < -1) {           // right-heavy
                if (right->getBalance() > 0)     // Right Left Case
                    right->rotateRight();
                rotateLeft();                    // Right Right Case
            }
        }

        // Unlinks the node holding the smallest value of the non-empty subtree
        // referenced by `slot`, rebalancing every node on the way back up, and
        // returns it. The caller owns (and must delete) the returned node.
        static AVLNode* detachMin(AVLNode*& slot) {
            if (slot->left == nullptr) {
                AVLNode* min = slot;
                slot = min->right;
                min->right = nullptr;
                return min;
            }
            AVLNode* min = detachMin(slot->left);
            slot->rebalance();
            return min;
        }

        // Adds one copy of `value` to this subtree.
        // Returns true if a new node was created, false if only a count grew.
        bool insert(int value) {
            if (value == this->value) {
                count++;
                return false;
            }

            bool newNodeAdded = true;
            AVLNode*& child = value > this->value ? right : left;
            if (child != nullptr)
                newNodeAdded = child->insert(value);
            else
                child = new AVLNode(value);

            rebalance();
            return newNodeAdded;
        }

        // Removes one copy of `value` from this subtree.
        // Returns {found, nodeRemoved}: `found` is false when the value is not
        // present; `nodeRemoved` is true when a node left the tree structure
        // (as opposed to only decrementing a count).
        // When this node is a leaf holding the last copy it deletes itself and
        // the caller must clear its pointer to it (callers detect this case
        // beforehand via value == v && count == 1 && desc == 0).
        std::pair<bool, bool> remove(int value) {
            if (value != this->value) {
                AVLNode*& child = value > this->value ? right : left;
                if (child == nullptr)
                    return {false, false};   // value is not stored in this subtree

                const bool childDeletesItself =
                    child->value == value && child->count == 1 && child->desc == 0;
                std::pair<bool, bool> result = child->remove(value);
                if (!result.first)
                    return result;
                if (childDeletesItself)
                    child = nullptr;

                rebalance();
                return result;
            }

            count--;
            if (count > 0)
                return {true, false};        // only the multiplicity changed

            if (left == nullptr && right == nullptr) { // No child
                delete this;                            // caller unlinks the pointer
            } else if (left == nullptr || right == nullptr) { // One child
                AVLNode* child = left != nullptr ? left : right;
                *this = *child;
                delete child;
            } else { // Both children: take over the in-order successor's contents
                AVLNode* successor = detachMin(right);
                this->value = successor->value;
                count = successor->count;
                delete successor;
                rebalance();
            }
            return {true, true};
        }

        // Number of copies of `value` stored in this subtree.
        int countEqual(int value) {
            if (value == this->value)
                return count;
            AVLNode* next = value > this->value ? right : left;
            return next != nullptr ? next->countEqual(value) : 0;
        }

        // Number of elements (with multiplicity) strictly greater than `value`.
        int countGreater(int value) {
            if (value == this->value)
                return sizeOf(right);
            if (value > this->value)
                return right != nullptr ? right->countGreater(value) : 0;
            // value < this->value: this node and its whole right subtree qualify.
            return sizeOf(right) + count + (left != nullptr ? left->countGreater(value) : 0);
        }

        // Sum of the elements (with multiplicity) strictly greater than `value`.
        int countGreaterSum(int value) {
            if (value == this->value)
                return sumOf(right);
            if (value > this->value)
                return right != nullptr ? right->countGreaterSum(value) : 0;
            // value < this->value: this node and its whole right subtree qualify.
            return sumOf(right) + this->value * count + (left != nullptr ? left->countGreaterSum(value) : 0);
        }
    };

    // Root of the tree; nullptr while the tree is empty.
    // NOTE(review): nodes are never freed (no destructor); adding one requires
    // also handling copies, since AVL objects are passed around by value.
    AVLNode* node = nullptr;

    int height() { return AVLNode::heightOf(node); }

    // Number of elements, counting duplicates.
    int size() { return AVLNode::sizeOf(node); }

    // Sum of all elements of the tree
    int sum() { return AVLNode::sumOf(node); }

    int getHeight() { return height(); }

    // Adds one copy of `value`.
    void insert(int value) {
        if (node != nullptr)
            node->insert(value);
        else
            node = new AVLNode(value);
    }

    // Removes one copy of `value`; does nothing if it is absent.
    void remove(int value) {
        if (node == nullptr)
            return;
        // A root leaf holding the last copy deletes itself inside remove().
        const bool rootDeletesItself = node->value == value && node->count == 1 && node->desc == 0;
        node->remove(value);
        if (rootDeletesItself)
            node = nullptr;
    }

    // Count of elements with value $value
    int countEqual(int value) {
        return node != nullptr ? node->countEqual(value) : 0;
    }

    // Count elements greater than value
    int countGreater(int value) {
        return node != nullptr ? node->countGreater(value) : 0;
    }

    // Count elements greater or equal than value
    // (composed from > and == so that no value±1 can overflow).
    int countGreaterEqual(int value) {
        return countGreater(value) + countEqual(value);
    }

    // Count elements lower than value
    int countLower(int value) {
        return size() - countGreaterEqual(value);
    }

    // Count elements lower or equal than value
    int countLowerEqual(int value) {
        return size() - countGreater(value);
    }

    // Count the sum of all elements greater than value
    int countGreaterSum(int value) {
        return node != nullptr ? node->countGreaterSum(value) : 0;
    }

    // Count the sum of all elements greater of equal than value
    int countGreaterEqualSum(int value) {
        return countGreaterSum(value) + value * countEqual(value);
    }

    // Count the sum of all elements lower than value
    int countLowerSum(int value) {
        return sum() - countGreaterEqualSum(value);
    }

    // Count the sum of all elements lower or equal than value
    int countLowerEqualSum(int value) {
        return sum() - countGreaterSum(value);
    }
};

void printAVLTree(AVL avl) {
    queue<AVL::AVLNode*> queue;
    queue.push(avl.node);

    AVL::AVLNode* node = avl.node;
    while (true)
        if (node->right != NULL)
            node = node->right;
        else
            break;
    int numDigits = to_string(node->value).size();

    int height = max(avl.getHeight(), 1); int currentHeight = height;

    vector<int> spaces = {0, max(numDigits, 1)}; 
    for (int i = 2; i < height+2; i++) 
        spaces.push_back(2*spaces[i-1]+spaces[1]);

    string connectionsString = "";
    for(int i = 1; ; i++) {
        AVL::AVLNode* node = queue.front();
        queue.pop();
        
        if ((int)log2(i)-log2(i) == 0) {
            cout << "\n";
            if (currentHeight == 0)
                break;
            for (int j = 0; j < spaces[currentHeight-1]; j++) {
                if (j < spaces[currentHeight-1]*4/5)
                    connectionsString += " ";
                cout << " ";
            }
            currentHeight--;
        }

        string value = node != NULL ? to_string(node->value) : ".";
        for (int j = 0; j < -(int)value.size()+numDigits; j++)
            cout << " ";
        cout << value;

        for (int j = 0; j < spaces[currentHeight+1]; j++) 
            cout << " ";

        if (node != NULL) {
            queue.push(node->left);
            queue.push(node->right);
        }else {
            queue.push(NULL);
            queue.push(NULL);
        }
    }
    cout << "\n";
}

// Dumps a subtree in pre-order (node, then left, then right). Each node is
// written as "{v: <value>, c: <count>} " and each empty child slot as "{.} ".
void printAVLNodes(AVL::AVLNode* node) {
    std::vector<AVL::AVLNode*> pending{node};
    while (!pending.empty()) {
        AVL::AVLNode* current = pending.back();
        pending.pop_back();

        if (current == nullptr) {
            std::cout << "{.} ";
            continue;
        }

        std::cout << "{v: " << current->value << ", c: " << current->count << "} ";
        // Push right first so the left subtree is emitted before the right one.
        pending.push_back(current->right);
        pending.push_back(current->left);
    }
}

// Demo: builds a small multiset AVL tree, removes a couple of values, prints a
// few counting queries and then the tree itself.
int main() {
    AVL avl;

    // Distinct values first, then some repeats.
    for (int value : {1, 2, 3, 5, 4, 0})
        avl.insert(value);
    for (int value : {2, 1, 1})
        avl.insert(value);

    // Each call removes a single copy.
    for (int value : {1, 5})
        avl.remove(value);

    std::cout << "There are " << avl.countEqual(1) << " elements with value 1\n";
    std::cout << "There are " << avl.countGreater(1) << " elements greater than 1\n";
    std::cout << "The sum of all elements greater than 1 is " << avl.countGreaterSum(1) << "\n";

    // Tree layout, then a raw pre-order node dump.
    printAVLTree(avl);
    printAVLNodes(avl.node);

    return 0;
}