1

I'm experimenting with a red-black tree implemented with std::unique_ptr, but it doesn't work.

My node definition:

/// Colour tag carried by every red-black tree node.
enum class Color { Red, Black };

/// A red-black tree node.
///
/// Each node owns its two children through `std::unique_ptr`; `parent` is a
/// non-owning back-pointer (null until the node is linked under another node).
template <typename T>
struct Node {
    T key;                            ///< Value stored in (and ordering) this node.
    Color color;                      ///< Node colour; new nodes start out red.
    std::unique_ptr<Node<T>> left;    ///< Owned left subtree (null if absent).
    std::unique_ptr<Node<T>> right;   ///< Owned right subtree (null if absent).
    Node<T>* parent;                  ///< Non-owning link to the parent node.

    /// Creates a detached red node holding a copy of `key`.
    ///
    /// The initializers are listed in declaration order: members are always
    /// initialized in the order they are declared, so listing them differently
    /// is misleading and triggers -Wreorder.
    Node(const T& key) : key {key}, color {Color::Red}, parent {nullptr} {}
};

I chose std::unique_ptr because std::shared_ptr is more expensive and each node exclusively owns its left and right children. Consequently, parent should be a non-owning raw pointer.

However, the logic of the Insert function conflicts with this basic design:

Here are my tree rotation functions. They accept an rvalue reference to std::unique_ptr because a rotation actually transfers ownership.

    // Left-rotates the subtree at `x`: x's right child `y` is linked where x
    // used to hang (root or a child pointer of x's parent) and x is then made
    // y's left child. `x->right` is dereferenced unconditionally, so x must
    // have a right child.
    void LeftRotate(std::unique_ptr<Node<T>>&& x) {
        auto y = std::move(x->right);
        // Remember y's left child as a raw pointer before ownership moves to x.
        auto yl = y->left.get();
        x->right = std::move(y->left);
        if (yl) {
            yl->parent = x.get();
        }
        y->parent = x->parent;
        // Raw alias for y; needed because `y` itself is moved from below.
        auto py = y.get();
        if (!x->parent) {
            root = std::move(y);
        } else if (x == x->parent->left) {
            x->parent->left = std::move(y);
        } else {
            x->parent->right = std::move(y);
        }
        // NOTE(review): if `x` refers to the very unique_ptr assigned just above
        // (root or the parent's child pointer), that move-assignment has already
        // replaced what `x` holds -- unique_ptr assignment destroys the previous
        // pointee -- so `x` would now hold y and the node originally passed in
        // would be gone. Confirm what callers actually pass.
        x->parent = py;
        py->left = std::move(x);
    }

    // Right-rotates the subtree at `x`: x's left child `y` is linked where x
    // used to hang (root or a child pointer of x's parent) and x is then made
    // y's right child. `x->left` is dereferenced unconditionally, so x must
    // have a left child.
    void RightRotate(std::unique_ptr<Node<T>>&& x) {
        auto y = std::move(x->left);
        // Remember y's right child as a raw pointer before ownership moves to x.
        auto yr = y->right.get();
        x->left = std::move(y->right);
        if (yr) {
            yr->parent = x.get();
        }
        y->parent = x->parent;
        // Raw alias for y; needed because `y` itself is moved from below.
        auto py = y.get();
        if (!x->parent) {
            root = std::move(y);
        } else if (x == x->parent->left) {
            x->parent->left = std::move(y);
        } else {
            x->parent->right = std::move(y);
        }
        // NOTE(review): if `x` refers to the very unique_ptr assigned just above
        // (root or the parent's child pointer), that move-assignment has already
        // replaced what `x` holds -- unique_ptr assignment destroys the previous
        // pointee -- so `x` would now hold y and the node originally passed in
        // would be gone. Confirm what callers actually pass.
        x->parent = py;
        py->right = std::move(x);
    }

Here is my Insert function:

public:
    // Public entry point: wraps `key` in a freshly allocated node and hands the
    // node to the private overload that links it into the tree.
    void Insert(const T& key) {
        Insert(std::make_unique<Node<T>>(key));
    }

private:
    // Links the node `z` into the tree as a leaf and then calls InsertFixup.
    // Takes ownership of `z`.
    void Insert(std::unique_ptr<Node<T>> z) {
        // Descend from the root: `x` is the current position, `y` trails one
        // step behind and ends up as the node under which `z` is attached.
        Node<T>* y = nullptr;
        Node<T>* x = root.get();
        while (x) {
            y = x;
            if (z->key < x->key) {
                x = x->left.get();
            } else {
                x = x->right.get();
            }
        }
        z->parent = y;
        // Move `z` into the unique_ptr that will own it (root when the tree was
        // empty, otherwise a child pointer of `y`), then pass that same
        // unique_ptr on to InsertFixup as an rvalue reference.
        if (!y) {
            root = std::move(z);
            InsertFixup(std::move(root));
        } else if (z->key < y->key) {
            y->left = std::move(z);
            InsertFixup(std::move(y->left));
        } else {
            y->right = std::move(z);
            InsertFixup(std::move(y->right));
        }
    }

    // Recolours and rotates after an insertion, looping while the tracked parent
    // `zp` is red, and finally colours the root black.
    // `z` is the unique_ptr holding the newly inserted node; `zp` is its parent,
    // `zpp` its grandparent and `y` the parent's sibling.
    void InsertFixup(std::unique_ptr<Node<T>>&& z) {
        auto zp = z->parent;
        while (zp && zp->color == Color::Red) {
            auto zpp = zp->parent;
            if (zp == zpp->left.get()) {
                auto y = zpp->right.get();
                if (y && y->color == Color::Red) {
                    // Parent and its sibling are both red: recolour and move up.
                    zp->color = Color::Black;
                    y->color = Color::Black;
                    zpp->color = Color::Red;
                    // NOTE(review): only `zp` is advanced here; `z` still refers
                    // to the originally inserted node -- check whether the later
                    // `z == zp->right` comparison is still meaningful.
                    zp = zpp->parent;
                } else {
                    if (z == zp->right) {
                        // NOTE(review): builds a new unique_ptr from the raw
                        // pointer `zp`; if `zp` is already owned by a unique_ptr
                        // in the tree, this creates a second owner of the node.
                        z = std::unique_ptr<Node<T>>(zp);
                        auto pz = z.get();
                        LeftRotate(std::move(z));
                        zp = pz->parent;
                        zpp = zp->parent;
                    }
                    zp->color = Color::Black;
                    zpp->color = Color::Red;
                    // NOTE(review): same concern -- `zpp` is a raw pointer, and
                    // wrapping it yields a new owner rather than the unique_ptr
                    // that presumably already holds the grandparent.
                    auto pzpp = std::unique_ptr<Node<T>>(zpp); // error
                    RightRotate(std::move(pzpp)); // error
                }
            } else {
                // Mirror image of the branch above with left and right swapped.
                auto y = zpp->left.get();
                if (y && y->color == Color::Red) {
                    zp->color = Color::Black;
                    y->color = Color::Black;
                    zpp->color = Color::Red;
                    zp = zpp->parent;
                } else {
                    if (z == zp->left) {
                        z = std::unique_ptr<Node<T>>(zp);
                        auto pz = z.get();
                        RightRotate(std::move(z));
                        zp = pz->parent;
                        zpp = zp->parent;
                    }
                    zp->color = Color::Black;
                    zpp->color = Color::Red;
                    auto pzpp = std::unique_ptr<Node<T>>(zpp); // error
                    LeftRotate(std::move(pzpp)); // error
                }
            }
        }
        root->color = Color::Black;
    }

The following lines in InsertFixup are buggy:

auto pzpp = std::unique_ptr<Node<T>>(zpp); // error
LeftRotate(std::move(pzpp)); // error

What I want to do is rotate the tree around the grandparent of node z.

However, the problem is that it is impossible to get the std::unique_ptr that owns the grandparent node (which must be passed to the LeftRotate function), because the parent link in my node implementation is only a raw pointer. Of course, I could walk down from the root to find the owning std::unique_ptr, but doing so on every rotation adds an extra search to each step and defeats the purpose of keeping the red-black tree's insert operation efficient.

Should I use std::shared_ptr instead? Is there any way to implement a red-black tree with std::unique_ptr?

frozenca
  • 839
  • 4
  • 14
  • I doubt the extra cost of a `shared_ptr` will show in the profile. But it might simplify the code greatly. Also, you could consider `weak_ptr` for parents. – Jeffrey Jul 16 '20 at 13:50
  • @Jeffrey Yes, if I choose to use ```std::shared_ptr``` for ```left, right``` then ```parent``` should be ```std::weak_ptr```. Actually, I've already implemented it once using ```std::shared_ptr```. But I've been told that C++ community discourages usage of ```std::shared_ptr``` unless pointers are shared across multiple threads with indeterminate order. – frozenca Jul 16 '20 at 13:52
  • Can't you walk up your raw pointers until you hit the node which owns the node of interest? I.e. the node which actually holds the `std::unique_ptr` – Sebastian Hoffmann Jul 16 '20 at 13:52
  • 1
    Yes, `unique_ptr` is the better choice here, as the data structure is the clear owner of the data. It's basically an implementation detail. There's a clear hierarchy and responsibility and no internal data is ever shared, thus `shared_ptr` is wrong here – Sebastian Hoffmann Jul 16 '20 at 13:54
  • @SebastianHoffmann I can climb up, but I need ```std::unique_ptr``` to rotate the tree. The problem is that I could get ```std::unique_ptr``` instance by only going down from the root. – frozenca Jul 16 '20 at 13:56
  • Code changes over time. unique_ptr overly restrict you, unless you have a solid reason to prefer them. In the future, your RB tree might need to share the pointers with some other module. Maybe an analyzer, or a UI, some serialization code, or whatever. When that day comes, if you have shared_ptr, it will be easy. If you have unique_ptr, you'll have to rewrite it all. But yeah, "this discussion has been moved to chat" – Jeffrey Jul 16 '20 at 13:58
  • 1
    @frozenca No, that's not true. Assuming that `zpp` is the left child: `std::unique_ptr zpp = std::move(zp->parent->parent->left)`. In the real code you will have to check which one it is. – Sebastian Hoffmann Jul 16 '20 at 13:58
  • @SebastianHoffmann Oh, sure. I can access great-grandma and access grandma via checking its left or right child. Extremely tedious, but worth to try it. Thanks for your help. – frozenca Jul 16 '20 at 14:00
  • @frozenca You are welcome. Remember to check that `zp->parent->parent` exists as well, i.e that `zpp` is not already the root – Sebastian Hoffmann Jul 16 '20 at 14:02

1 Answer

1

Now I made a correct implementation. It works nicely with the random order insertion/deletion test.

Complete code (including tests):

#include <algorithm>
#include <cassert>
#include <iostream>
#include <memory>
#include <numeric>
#include <random>
#include <utility>
#include <vector>

// Program-wide Mersenne Twister engine, seeded once from std::random_device.
std::mt19937 gen{std::random_device{}()};

/// Colour tag carried by every red-black tree node.
enum class Color { Red, Black };

/// A red-black tree node.
///
/// Each node owns its two children through `std::unique_ptr`; `parent` is a
/// non-owning back-pointer (null until the node is linked under another node).
template <typename T>
struct Node {
    T key;                            ///< Value stored in (and ordering) this node.
    Color color;                      ///< Node colour; new nodes start out red.
    std::unique_ptr<Node<T>> left;    ///< Owned left subtree (null if absent).
    std::unique_ptr<Node<T>> right;   ///< Owned right subtree (null if absent).
    Node<T>* parent;                  ///< Non-owning link to the parent node.

    /// Creates a detached red node holding a copy of `key`.
    ///
    /// The initializers are listed in declaration order: members are always
    /// initialized in the order they are declared, so listing them differently
    /// is misleading and triggers -Wreorder.
    Node(const T& key) : key {key}, color {Color::Red}, parent {nullptr} {}
};

template <typename T>
struct RBTree {
public:
    std::unique_ptr<Node<T>> root;

private:
    void LeftRotate(std::unique_ptr<Node<T>>&& x) {
        auto y = std::move(x->right);
        x->right = std::move(y->left);
        if (x->right) {
            x->right->parent = x.get();
        }
        y->parent = x->parent;
        auto xp = x->parent;
        if (!xp) {
            auto px = x.release();
            root = std::move(y);
            root->left = std::unique_ptr<Node<T>>(px);
            root->left->parent = root.get();
        } else if (x == xp->left) {
            auto px = x.release();
            xp->left = std::move(y);
            xp->left->left = std::unique_ptr<Node<T>>(px);
            xp->left->left->parent = xp->left.get();
        } else {
            auto px = x.release();
            xp->right = std::move(y);
            xp->right->left = std::unique_ptr<Node<T>>(px);
            xp->right->left->parent = xp->right.get();
        }
    }

    void RightRotate(std::unique_ptr<Node<T>>&& x) {
        auto y = std::move(x->left);
        x->left = std::move(y->right);
        if (x->left) {
            x->left->parent = x.get();
        }
        y->parent = x->parent;
        auto xp = x->parent;
        if (!xp) {
            auto px = x.release();
            root = std::move(y);
            root->right = std::unique_ptr<Node<T>>(px);
            root->right->parent = root.get();
        } else if (x == xp->left) {
            auto px = x.release();
            xp->left = std::move(y);
            xp->left->right = std::unique_ptr<Node<T>>(px);
            xp->left->right->parent = xp->left.get();
        } else {
            auto px = x.release();
            xp->right = std::move(y);
            xp->right->right = std::unique_ptr<Node<T>>(px);
            xp->right->right->parent = xp->right.get();
        }
    }

public:
    Node<T>* Search(const T& key) {
        return Search(root.get(), key);
    }

    void Insert(const T& key) {
        auto z = std::make_unique<Node<T>>(key);
        Insert(std::move(z));
    }

    void Delete(const T& key) {
        auto z = Search(key);
        Delete(z);
    }

private:
    Node<T>* Search(Node<T>* x, const T& key) {
        if (!x || x->key == key) {
            return x;
        }
        if (key < x->key) {
            return Search(x->left.get(), key);
        } else {
            return Search(x->right.get(), key);
        }
    }

    void Insert(std::unique_ptr<Node<T>> z) {
        Node<T>* y = nullptr;
        Node<T>* x = root.get();
        while (x) {
            y = x;
            if (z->key < x->key) {
                x = x->left.get();
            } else {
                x = x->right.get();
            }
        }
        z->parent = y;
        if (!y) {
            root = std::move(z);
            InsertFixup(std::move(root));
        } else if (z->key < y->key) {
            y->left = std::move(z);
            InsertFixup(std::move(y->left));
        } else {
            y->right = std::move(z);
            InsertFixup(std::move(y->right));
        }
    }

    void InsertFixup(std::unique_ptr<Node<T>>&& z) {
        auto zp = z->parent;
        while (zp && zp->color == Color::Red) {
            auto zpp = zp->parent;
            if (zp == zpp->left.get()) {
                auto y = zpp->right.get();
                if (y && y->color == Color::Red) {
                    zp->color = Color::Black;
                    y->color = Color::Black;
                    zpp->color = Color::Red;
                    zp = zpp->parent;
                } else {
                    if (z == zp->right) {
                        LeftRotate(std::move(zpp->left));
                        zp = zpp->left.get();
                    }
                    zp->color = Color::Black;
                    zpp->color = Color::Red;
                    auto zppp = zpp->parent;
                    if (!zppp) {
                        RightRotate(std::move(root));
                    } else if (zpp == zppp->left.get()) {
                        RightRotate(std::move(zppp->left));
                    } else {
                        RightRotate(std::move(zppp->right));
                    }
                }
            } else {
                auto y = zpp->left.get();
                if (y && y->color == Color::Red) {
                    zp->color = Color::Black;
                    y->color = Color::Black;
                    zpp->color = Color::Red;
                    zp = zpp->parent;
                } else {
                    if (z == zp->left) {
                        RightRotate(std::move(zpp->right));
                        zp = zpp->right.get();
                    }
                    zp->color = Color::Black;
                    zpp->color = Color::Red;
                    auto zppp = zpp->parent;
                    if (!zppp) {
                        LeftRotate(std::move(root));
                    } else if (zpp == zppp->left.get()) {
                        LeftRotate(std::move(zppp->left));
                    } else {
                        LeftRotate(std::move(zppp->right));
                    }
                }
            }
        }
        root->color = Color::Black;
    }

    Node<T>* Transplant(Node<T>* u, std::unique_ptr<Node<T>>&& v) {
        if (v) {
            v->parent = u->parent;
        }
        Node<T>* w = nullptr;
        if (!u->parent) {
            w = root.release();
            root = std::move(v);
        } else if (u == u->parent->left.get()) {
            w = u->parent->left.release();
            u->parent->left = std::move(v);
        } else {
            w = u->parent->right.release();
            u->parent->right = std::move(v);
        }
        return w;
    }

    Node<T>* Minimum(Node<T>* x) {
        if (!x) {
            return x;
        }
        while (x->left) {
            x = x->left.get();
        }
        return x;
    }

    void Delete(Node<T>* z) {
        if (!z) {
            return;
        }
        Color orig_color = z->color;
        Node<T>* x = nullptr;
        Node<T>* xp = nullptr;
        if (!z->left) {
            x = z->right.get();
            xp = z->parent;
            auto pz = Transplant(z, std::move(z->right));
            auto upz = std::unique_ptr<Node<T>>(pz);
        } else if (!z->right) {
            x = z->left.get();
            xp = z->parent;
            auto pz = Transplant(z, std::move(z->left));
            auto upz = std::unique_ptr<Node<T>>(pz);
        } else {
            auto y = Minimum(z->right.get());
            orig_color = y->color;
            x = y->right.get();
            xp = y;
            if (y->parent == z) {
                if (x) {
                    x->parent = y;
                }
                auto pz = Transplant(z, std::move(z->right));
                y->left = std::move(pz->left);
                y->left->parent = y;
                y->color = pz->color;
                auto upz = std::unique_ptr<Node<T>>(pz);
            } else {
                xp = y->parent;
                auto py = Transplant(y, std::move(y->right));
                py->right = std::move(z->right);
                py->right->parent = py;
                auto upy = std::unique_ptr<Node<T>>(py);
                auto pz = Transplant(z, std::move(upy));
                py->left = std::move(pz->left);
                py->left->parent = py;
                py->color = pz->color;
                auto upz = std::unique_ptr<Node<T>>(pz);
            }
        }
        if (orig_color == Color::Black) {
            DeleteFixup(x, xp);
        }
    }

    void DeleteFixup(Node<T>* x, Node<T>* xp) {
        while (x != root.get() && (!x || x->color == Color::Black)) {
            if (x == xp->left.get()) {
                Node<T>* w = xp->right.get();
                if (w && w->color == Color::Red) {
                    w->color = Color::Black;
                    xp->color = Color::Red;
                    auto xpp = xp->parent;
                    if (!xpp) {
                        LeftRotate(std::move(root));
                    } else if (xp == xpp->left.get()) {
                        LeftRotate(std::move(xpp->left));
                    } else {
                        LeftRotate(std::move(xpp->right));
                    }
                    w = xp->right.get();
                }
                if (w && (!w->left || w->left->color == Color::Black)
                && (!w->right || w->right->color == Color::Black)) {
                    w->color = Color::Red;
                    x = xp;
                    xp = xp->parent;
                } else if (w) {
                    if (!w->right || w->right->color == Color::Black) {
                        w->left->color = Color::Black;
                        w->color = Color::Red;
                        auto wp = w->parent;
                        if (!wp) {
                            RightRotate(std::move(root));
                        } else if (w == wp->left.get()) {
                            RightRotate(std::move(wp->left));
                        } else {
                            RightRotate(std::move(wp->right));
                        }
                        w = xp->right.get();
                    }
                    w->color = xp->color;
                    xp->color = Color::Black;
                    w->right->color = Color::Black;
                    auto xpp = xp->parent;
                    if (!xpp) {
                        LeftRotate(std::move(root));
                    } else if (xp == xpp->left.get()) {
                        LeftRotate(std::move(xpp->left));
                    } else {
                        LeftRotate(std::move(xpp->right));
                    }
                    x = root.get();
                } else {
                    x = root.get();
                }
            } else {
                Node<T>* w = xp->left.get();
                if (w && w->color == Color::Red) {
                    w->color = Color::Black;
                    xp->color = Color::Red;
                    auto xpp = xp->parent;
                    if (!xpp) {
                        RightRotate(std::move(root));
                    } else if (xp == xpp->left.get()) {
                        RightRotate(std::move(xpp->left));
                    } else {
                        RightRotate(std::move(xpp->right));
                    }
                    w = xp->left.get();
                }
                if (w && (!w->left || w->left->color == Color::Black)
                    && (!w->right || w->right->color == Color::Black)) {
                    w->color = Color::Red;
                    x = xp;
                    xp = xp->parent;
                } else if (w) {
                    if (!w->left || w->left->color == Color::Black) {
                        w->right->color = Color::Black;
                        w->color = Color::Red;
                        auto wp = w->parent;
                        if (!wp) {
                            LeftRotate(std::move(root));
                        } else if (w == wp->left.get()) {
                            LeftRotate(std::move(wp->left));
                        } else {
                            LeftRotate(std::move(wp->right));
                        }
                        w = xp->left.get();
                    }
                    w->color = xp->color;
                    xp->color = Color::Black;
                    w->left->color = Color::Black;
                    auto xpp = xp->parent;
                    if (!xpp) {
                        RightRotate(std::move(root));
                    } else if (xp == xpp->left.get()) {
                        RightRotate(std::move(xpp->left));
                    } else {
                        RightRotate(std::move(xpp->right));
                    }
                    x = root.get();
                } else {
                    x = root.get();
                }
            }
        }
        if (x) {
            x->color = Color::Black;
        }
    }

};

/// Streams the subtree rooted at `node` in order (left subtree, node, right
/// subtree). Each key is followed by "● " for a black node or "○ " for a red
/// one; a null `node` writes nothing.
template <typename T>
std::ostream& operator<<(std::ostream& os, Node<T>* node) {
    if (!node) {
        return os;
    }
    os << node->left.get();
    os << node->key << (node->color == Color::Black ? "● " : "○ ");
    os << node->right.get();
    return os;
}

/// Streams the whole tree by streaming its root node pointer.
template <typename T>
std::ostream& operator<<(std::ostream& os, const RBTree<T>& tree) {
    return os << tree.root.get();
}

int main() {
    constexpr size_t SIZE = 100;
    std::vector<int> v (SIZE);
    std::iota(v.begin(), v.end(), 1);
    std::shuffle(v.begin(), v.end(), gen);
    RBTree<int> rbtree;
    for (auto n : v) {
        rbtree.Insert(n);
    }
    std::cout << '\n';
    std::cout << rbtree << '\n';
    std::shuffle(v.begin(), v.end(), gen);
    for (auto n : v) {
        rbtree.Delete(n);
        std::cout << rbtree << '\n';
    }

}
frozenca
  • 839
  • 4
  • 14