1

I am writing some code for splay tree nodes. Without getting too technical, I want to implement one base tree and one derived tree that supports reversing the left and right subtrees. The current excerpt looks like this:

/** Base splay-tree node (excerpt): links plus a subtree-size counter. */
struct node {
  node *f;     // parent link
  node *c[2];  // child links
  int size;    // subtree size
  void push_down() {}  // nothing to propagate in the base node
};

// Derived node with a lazy "reverse" flag, shown as the motivating example.
// As written this does NOT compile: c[0] and c[1] are declared in node as
// node *, and node has no member r.
struct reversable_node : node {
  int r;  // nonzero: this node's children still have to be swapped
  void push_down() {
    if (r) {
      std::swap(c[0], c[1]);
      // Error here: node has no member r. This line would also dereference a
      // null child for a node that has fewer than two children.
      c[0]->r ^= 1, c[1]->r ^= 1, r = 0;
    }
  }
};

This obviously does not work, because c[0] and c[1] are of type node * and node does not have a member r. Still, I know that the children of a node only point to node, and the children of a reversable_node only point to reversable_node. So I can do some casts:

      ((reversable_node *)c[0])->r ^= 1, ((reversable_node *)c[1])->r ^= 1, r = 0;

But this looks super clumsy. Is there a better way to declare a self-referential pointer in the base class that also works in derived classes?

P.S. The whole code looks like this:

/**
 * Base splay-tree node: parent link f, child links c[0]/c[1], and the number
 * of nodes in the subtree rooted here (1 for a fresh, unlinked node).
 */
struct node {
  node *f, *c[2];
  int size;
  node() : f(nullptr), c{nullptr, nullptr}, size(1) {}
  /** The base node carries no lazy tag, so there is nothing to push. */
  void push_down() {}
  /** Recompute size as 1 plus the sizes of the non-null children. */
  void update() {
    size = 1;
    for (node *child : c)
      if (child) size += child->size;
  }
};

struct reversable_node : node {
  int r;
  reversable_node() : node() { r = 0; }
  void push_down() {
    if (r) {
      std::swap(c[0], c[1]);
      ((reversable_node *)c[0])->r ^= 1, ((reversable_node *)c[1])->r ^= 1, r = 0;
    }
  }
};

/**
 * Implicit-key splay tree over a fixed pool of T nodes (T derives from node).
 *
 * pool[0] and pool[1] are sentinels bounding the sequence, so user position
 * pos is stored at in-order index pos + 1. Links in node are typed node *;
 * they are downcast to T * (every link points into pool, so this is valid) so
 * that T::push_down is reached. A node * call would resolve to the base no-op,
 * because push_down is not virtual.
 */
template <typename T = node, int MAXSIZE = 500000>
struct tree {
  T pool[MAXSIZE + 2];
  T *root;
  int size;  // pool slots handed out so far, including the two sentinels
  tree() {
    size = 2;
    root = pool, root->c[1] = pool + 1, root->size = 2;
    pool[1].f = root;
  }
  /** Downcast a link to the concrete node type; null stays null. */
  static T *as_t(node *p) { return static_cast<T *>(p); }
  /**
   * Rotate n above its parent, relinking the grandparent, then refresh sizes.
   * No push_down here: callers reach n through walk(), which has already
   * pushed every node on the path, so the child sides read below are final.
   */
  void rotate(T *n) {
    T *p = as_t(n->f);
    int v = p->c[0] == n;  // 1 when n is a left child
    node *m = n->c[v];     // n's inner subtree; it moves under p
    if (p->f) p->f->c[p->f->c[1] == p] = n;
    n->f = p->f, n->c[v] = p;
    p->f = n, p->c[v ^ 1] = m;
    if (m) m->f = p;
    p->update(), n->update();
  }
  /** Splay n until its parent is s (to the root when s is null). */
  void splay(T *n, T *s = nullptr) {
    while (n->f != s) {
      T *m = as_t(n->f), *l = as_t(m->f);
      if (l == s)
        rotate(n);
      else if ((l->c[0] == m) == (m->c[0] == n))
        rotate(m), rotate(n);
      else
        rotate(n), rotate(n);
    }
    if (!s) root = n;
  }
  /** Hand out the next unused pool node, or nullptr once the pool is exhausted. */
  T *new_node() { return size < MAXSIZE + 2 ? pool + size++ : nullptr; }
  /**
   * Push down n, then pick the child holding in-order index pos: v = 1 (right)
   * when pos lies beyond n's left subtree, in which case pos is rebased past
   * n and that subtree; otherwise v = 0. Returns n's left-subtree size.
   */
  int walk(node *n, int &v, int &pos) {
    T *x = as_t(n);
    x->push_down();
    int s = x->c[0] ? x->c[0]->size : 0;
    v = s < pos;
    if (v) pos -= s + 1;
    return s;
  }
  /** Insert n so that it ends up at position pos, then splay it to the root. */
  void add_node(node *n, int pos) {
    if (!n) return;  // e.g. new_node() ran out of pool
    T *c = root;
    int v;
    ++pos;  // skip the left sentinel
    for (;;) {
      walk(c, v, pos);
      if (!c->c[v]) break;
      c = as_t(c->c[v]);
    }
    c->c[v] = n, n->f = c;
    splay(as_t(n));
  }
  /**
   * Return the node at position pos (-1 and the element count address the two
   * sentinels), or nullptr when pos is outside that range. When do_splay is
   * nonzero the found node is splayed to the root.
   */
  T *find(int pos, int do_splay = true) {
    T *c = root;
    ++pos;
    while (c) {
      int v;
      int s = walk(c, v, pos);
      if (!v && pos == s) break;  // exactly s nodes precede c: found
      c = as_t(c->c[v]);
    }
    if (c && do_splay) splay(c);
    return c;
  }
  /**
   * Return the root of the subtree holding exactly positions [posl, posr)
   * (already pushed down), or nullptr if the range is empty or out of bounds.
   */
  T *find_range(int posl, int posr) {
    if (posr <= posl) return nullptr;
    T *l = find(posl - 1), *r = find(posr, false);
    if (!l || !r) return nullptr;
    splay(r, l);  // r becomes l's right child; its left subtree is the range
    T *mid = as_t(r->c[0]);
    if (mid) mid->push_down();
    return mid;
  }
};

So basically, each node has a flag recording whether its subtree still needs to be reversed, and when we rotate the tree, we push the flag down from the node to its children. This may require some understanding of splay trees.

P.S. 2: It is meant to be a library; a typical use case looks like this:

#include "../template.h"

splay::tree<splay::reversable_node> s;

// In-order traversal that pushes each node's pending reversal before reading
// its children. Links are declared as node *, so they are downcast to the
// concrete node type before recursing.
void dfs(splay::reversable_node *n) {
  if (n) {
    // Push down the flag.
    n->push_down();
    dfs(static_cast<splay::reversable_node *>(n->c[0]));
    // Do something about n...
    dfs(static_cast<splay::reversable_node *>(n->c[1]));
  }
}

int main() {
  // Insert 5 nodes to the Splay Tree.
  for (int i = 0; i < 5; ++i) s.add_node(s.new_node(), 0);
  // Find a range of the tree.
  splay::reversable_node *n =
      static_cast<splay::reversable_node *>(s.find_range(0, 3));
  // Reverse it lazily. push_down() swaps the children itself when r is set,
  // so swapping them here as well would cancel the reversal at this node.
  if (n) n->r ^= 1;
  // Traverse it in inorder.
  dfs(static_cast<splay::reversable_node *>(s.root));
}
zhtluo
  • 33
  • 4
  • Please post your whole code. How is the object of `struct` initialized? Also, what is `push_down` supposed to do? – kiner_shah Oct 31 '21 at 06:35
  • 1
    @kiner_shah Basically, when we rotate the tree, we push the flag down from the node to its children. While there may be some other design pattern to solve the problem, tree reversal is so inherent in the structure of the tree that I found it difficult to isolate. – zhtluo Oct 31 '21 at 06:44
  • Can you please also post the `main()` which initializes your data structure and calls its methods? Also, what is the point of having `node` - can't you simply use all members of `node` in `reversible_node`? – kiner_shah Oct 31 '21 at 06:47
  • Do all kinds of nodes need to inherit the same base class (ie. for polymorphism)? If not, what about [CRTP](https://en.cppreference.com/w/cpp/language/crtp)? – IWonderWhatThisAPIDoes Oct 31 '21 at 06:51
  • Consider what you can achieve if `push_down()` is a virtual function of `node` that is specialised by `reversible_node`. (That's a hint). Also, if you're going to use hackery like `c[0]->r ^= 1, c[1]->r ^= 1, r = 0` you need to explain what it does, and why that hackery is better than more readable alternatives – Peter Oct 31 '21 at 06:57
  • 1
    @kiner_shah Well yes, but I also plan to write other kinds of trees without the need of reversion. So it would be best to have a base class to work on. – zhtluo Oct 31 '21 at 06:58
  • @IWonderWhatThisAPIDoes I thought about some template ideas. The downside is that the base class can never be instantiated. Still I think this could work, but I need to consider a bit deeper on the design. – zhtluo Oct 31 '21 at 07:05
  • @Peter Since I know the exact type every time I think it is more a 'I don't want to cast it every time I use it' thing than a virtual function thing. And let us hope I never have to defend my spaghetti code for programming contests before a committee. :) – zhtluo Oct 31 '21 at 07:11
  • @zhtluo - The majority of cases that rely on downcasting can be better addressed by careful redesign with use of polymorphism and eliminating the need for down-casting. But suit yourself - you've decided you need to downcast. I do not believe it is needed. – Peter Oct 31 '21 at 07:19
  • @Peter I see what you mean. By redesigning `push_down` into a virtual `receive_push_down` function, we can invoke it on the child node and eliminate the need to cast. Still, this incurs some runtime cost and makes it harder to write a new node (which is probably what I will be doing during the contest). I will think about it. – zhtluo Oct 31 '21 at 07:29
  • You might also want to read up on the subject of what some gurus describe as "premature optimization". – Peter Oct 31 '21 at 07:37

1 Answer

0

Anyway, thanks to CRTP, I got it to work.

namespace splay {

/**
 * Abstract node struct (CRTP base). T is the concrete node type, so the
 * parent link f and the child links c[0]/c[1] are already T * and need no
 * casts. size counts the nodes in the subtree rooted here.
 */
template <typename T>
struct node {
  T *f, *c[2];
  int size;
  node() : f(nullptr), c{nullptr, nullptr}, size(1) {}
  /** No lazy tag at this level; derived node types add their own. */
  void push_down() {}
  /** Recompute size as 1 plus the sizes of the non-null children. */
  void update() {
    size = 1;
    for (T *child : c)
      if (child) size += child->size;
  }
};

/**
 * Abstract reversible node struct (CRTP). reverse() swaps this node's
 * children immediately and toggles r; while r is set, push_down() forwards the
 * reversal to both children and then clears r.
 */
template <typename T>
struct reversible_node : node<T> {
  int r;
  reversible_node() : node<T>(), r(0) {}
  void push_down() {
    node<T>::push_down();
    if (!r) return;
    for (T *child : this->c)
      if (child) child->reverse();
    r = 0;
  }
  void update() { node<T>::update(); }
  /**
   * Reverse the range of this node.
   */
  void reverse() {
    std::swap(this->c[0], this->c[1]);
    r ^= 1;
  }
};

/**
 * Implicit-key splay tree over a fixed pool of T nodes.
 *
 * pool[0] and pool[1] are sentinels bounding the sequence, so user position
 * pos lives at in-order index pos + 1; find(-1) and find(count) reach the
 * sentinels. T must provide f, c[2], size, push_down() and update().
 * The whole pool is a data member, so one tree object is large.
 */
template <typename T, int MAXSIZE = 500000>
struct tree {
  T pool[MAXSIZE + 2];
  T *root;
  int size;  // pool slots handed out so far, including the two sentinels
  tree() {
    size = 2;
    root = pool, root->c[1] = pool + 1, root->size = 2;
    pool[1].f = root;
  }
  /**
   * Helper function to rotate node n above its parent. It does not push down
   * lazy tags itself: callers reach n through walk(), which has already
   * pushed every node on the path.
   */
  void rotate(T *n) {
    int v = n->f->c[0] == n;  // 1 when n is a left child
    T *p = n->f, *m = n->c[v];
    if (p->f) p->f->c[p->f->c[1] == p] = n;
    n->f = p->f, n->c[v] = p;
    p->f = n, p->c[v ^ 1] = m;
    if (m) m->f = p;
    p->update(), n->update();
  }
  /**
   * Splay n so that it is under s (or to root if s is null).
   */
  void splay(T *n, T *s = nullptr) {
    while (n->f != s) {
      T *m = n->f, *l = m->f;
      if (l == s)
        rotate(n);
      else if ((l->c[0] == m) == (m->c[0] == n))
        rotate(m), rotate(n);
      else
        rotate(n), rotate(n);
    }
    if (!s) root = n;
  }
  /**
   * Get a new node from the pool, or nullptr once all MAXSIZE user nodes are
   * in use (previously this returned a pointer past the end of pool).
   */
  T *new_node() { return size < MAXSIZE + 2 ? pool + size++ : nullptr; }
  /**
   * Helper function to walk down the tree: push down n, then set v to the
   * child holding in-order index pos (1 = right, with pos rebased past n and
   * its left subtree; 0 = left subtree or n itself).
   * Returns the size of n's left subtree.
   */
  int walk(T *n, int &v, int &pos) {
    n->push_down();
    int s = n->c[0] ? n->c[0]->size : 0;
    v = s < pos;
    if (v) pos -= s + 1;
    return s;
  }
  /**
   * Insert node n to position pos and splay it to the root. A null n (e.g.
   * from an exhausted pool) is ignored.
   */
  void insert(T *n, int pos) {
    if (!n) return;
    T *c = root;
    int v;
    ++pos;  // skip the left sentinel
    while (walk(c, v, pos), c->c[v]) c = c->c[v];
    c->c[v] = n, n->f = c, splay(n);
  }
  /**
   * Find the node at position pos; returns nullptr when pos is outside
   * [-1, count]. If sp is true, splay the found node to the root.
   */
  T *find(int pos, int sp = true) {
    T *c = root;
    ++pos;
    while (c) {
      int v;
      // Separate statements: walk() may modify pos, so don't read pos in the
      // same expression as the call.
      int left = walk(c, v, pos);
      if (!v && pos == left) break;  // exactly `left` nodes precede c
      c = c->c[v];
    }
    if (c && sp) splay(c);
    return c;
  }
  /**
   * Find the range [posl, posr) on the splay tree: returns the (pushed-down)
   * root of the subtree holding exactly those nodes, or nullptr when the range
   * is empty or out of bounds.
   */
  T *find_range(int posl, int posr) {
    if (posr <= posl) return nullptr;
    T *l = find(posl - 1), *r = find(posr, false);
    if (!l || !r) return nullptr;
    splay(r, l);  // r becomes l's right child; its left subtree is the range
    if (r->c[0]) r->c[0]->push_down();
    return r->c[0];
  }
};

}  // namespace splay

An example use case:

/**
 * Concrete node for the example: a reversible node carrying a value.
 * val defaults to 0, which inorder() uses to skip the sentinels. The explicit
 * initializer keeps that true for trees that are not in static storage, where
 * the member would otherwise be indeterminate.
 */
struct node : splay::reversible_node<node> {
  int val = 0;
  void push_down() { splay::reversible_node<node>::push_down(); }
  void update() { splay::reversible_node<node>::update(); }
};

// The tree embeds its whole node pool (MAXSIZE + 2 nodes). Defined at
// namespace scope, it has static storage duration: the pool is not on the
// stack, and it is zero-initialized before the constructor runs.
splay::tree<node> t;

// Read in main(): N values 1..N are inserted, then M range reversals follow.
int N, M;

/**
 * Print the values in the subtree rooted at n in order, separated by single
 * spaces, skipping nodes whose val is 0 (the sentinels). Each node's pending
 * reversal is pushed down before its children are read. The static flag f
 * persists across calls, so later calls continue the same space-separated line.
 *
 * Iterative, using parent links: right after N appends the tree is a path of
 * depth ~N, which would overflow the call stack if traversed recursively.
 */
void inorder(node *n) {
  static int f = 0;
  if (!n) return;
  // Go to the leftmost node under x, pushing tags before reading each c[0].
  auto leftmost = [](node *x) -> node * {
    x->push_down();
    while (x->c[0]) {
      x = x->c[0];
      x->push_down();
    }
    return x;
  };
  node *const stop = n->f;  // climbing back to here means the subtree is done
  n = leftmost(n);
  while (n != stop) {
    if (n->val) {
      if (f) printf(" ");
      f = 1;
      printf("%d", n->val);
    }
    if (n->c[1]) {
      // Successor: leftmost node of the right subtree.
      n = leftmost(n->c[1]);
    } else {
      // Climb until we arrive from a left child; that ancestor comes next.
      // Ancestors were already pushed on the way down, so their links are final.
      node *from = n;
      n = n->f;
      while (n != stop && n->c[1] == from) {
        from = n;
        n = n->f;
      }
    }
  }
}

/**
 * Read N, insert the values 1..N in order, apply M reversals of the 1-based
 * inclusive ranges [u, v], then print the sequence. Exits with status 1 on
 * unreadable input or when N exceeds the tree's pool.
 */
int main() {
  if (scanf("%d%d", &N, &M) != 2) return 1;
  for (int i = 0; i < N; ++i) {
    node *n = t.new_node();
    if (!n) return 1;  // pool exhausted
    n->val = i + 1;
    t.insert(n, i);
  }
  for (int i = 0, u, v; i < M; ++i) {
    if (scanf("%d%d", &u, &v) != 2) return 1;
    // 1-based inclusive [u, v] is the 0-based half-open range [u - 1, v).
    node *n = t.find_range(u - 1, v);
    if (n) n->reverse();  // nullptr for an empty range
  }
  inorder(t.root);
}

Hopefully this lets me write splay trees faster in competitive programming (CP).

zhtluo
  • 33
  • 4