
I'm trying to implement an iterator class for not-necessarily-binary trees in Python. Once the iterator is constructed with a tree's root node, its next() function can be called repeatedly to traverse the tree in depth-first order (pre-order: each parent before its children, children left to right), finally returning None when there are no nodes left.

Here is the basic Node class for a tree:

class Node(object):

    def __init__(self, title, children=None):
        self.title = title
        self.children = children or []
        self.visited = False   

    def __str__(self):
        return self.title

As you can see above, I introduced a visited property to the nodes for my first approach, since I didn't see a way around it. With that extra measure of state, the Iterator class looks like this:

class Iterator(object):

    def __init__(self, root):
        self.stack = []
        self.current = root

    def next(self):
        if self.current is None:
            return None

        self.stack.append(self.current)
        self.current.visited = True

        # Root case
        if len(self.stack) == 1:
            return self.current

        while self.stack:
            self.current = self.stack[-1] 
            for child in self.current.children:
                if not child.visited:
                    self.current = child
                    return child

            self.stack.pop()

This is all well and good, but I want to get rid of the need for the visited property, without resorting to recursion or any other alterations to the Node class.

All the state I need should be taken care of in the iterator, but I'm at a loss about how that can be done. Keeping a visited list for the whole tree is non-scalable and out of the question, so there must be a clever way to use the stack.

What especially confuses me is this--since the next() function, of course, returns, how can I remember where I've been without marking anything or using excess storage? Intuitively, I think of looping over children, but that logic is broken/forgotten when the next() function returns!

UPDATE - Here is a small test:

tree = Node(
    'A', [
        Node('B', [
            Node('C', [
                Node('D')
                ]),
            Node('E'),
            ]),
        Node('F'),
        Node('G'),
        ])

iter = Iterator(tree)

out = object()
while out:
    out = iter.next()
    print(out)
norman
  • Keeping a visited *list* might be non-scalable, but what about a visited set, e.g. based on Node object id? – michaelb Oct 01 '14 at 16:10
  • That could still potentially hold every label, though. I want the iterator to keep only a subset of the tree at a time. – norman Oct 01 '14 at 16:12
  • What is the expected output of the "small test"? – Robᵩ Oct 01 '14 at 16:49
  • It should give `A B C D E F G None`. BTW, implementing @mgilson's generator solutions as the `next()` bodies results in an infinite loop, but that might just be my poor adaptation/understanding of generators. – norman Oct 01 '14 at 16:52
  • I'm not sure how you're trying to do it, but I've updated my answer to show it passing your tests. (Note that I've coded it assuming that you'd traverse it in a `for` loop, not some sort of while loop. I don't see your while loop working anywhere, since you don't catch the `StopIteration` exception.) – mgilson Oct 01 '14 at 17:08

2 Answers


If you really must avoid recursion, this iterator works:

from collections import deque

def node_depth_first_iter(node):
    stack = deque([node])
    while stack:
        # Pop out the first element in the stack
        node = stack.popleft()
        yield node
        # Push the children onto the front of the stack.
        # deque.extendleft inserts items one at a time, so the first one in
        # ends up the last one out; push them in reverse order to compensate.
        stack.extendleft(reversed(node.children))
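
To get the `None`-terminated behaviour from the question out of this generator, one option is the two-argument form of the built-in `next()`, which returns a default instead of raising `StopIteration`. The following is a minimal sketch, assuming the generator object is created once and then reused. Calling `node_depth_first_iter(tree)` again on every step would restart at the root each time, which is one possible cause of the infinite loop mentioned in the comments. `Node` and `tree` are copied from the question; `titles` is a name introduced here for illustration.

```python
from collections import deque


class Node(object):
    def __init__(self, title, children=None):
        self.title = title
        self.children = children or []

    def __str__(self):
        return self.title


def node_depth_first_iter(node):
    stack = deque([node])
    while stack:
        node = stack.popleft()
        yield node
        stack.extendleft(reversed(node.children))


tree = Node('A', [
    Node('B', [Node('C', [Node('D')]), Node('E')]),
    Node('F'),
    Node('G'),
])

gen = node_depth_first_iter(tree)  # create the generator exactly once
titles = []
while True:
    node = next(gen, None)  # None is returned once the generator is exhausted
    if node is None:
        break
    titles.append(str(node))
```
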

With that said, I think you're overthinking this. A good ol' recursive generator also does the trick:

class Node(object):

    def __init__(self, title, children=None):
        self.title = title
        self.children = children or []

    def __str__(self):
        return self.title

    def __iter__(self):
        yield self
        for child in self.children:
            for node in child:
                yield node

Both of these pass your tests:

expected = ['A', 'B', 'C', 'D', 'E', 'F', 'G']
# Test recursive generator using Node.__iter__
assert [str(n) for n in tree] == expected

# test non-recursive Iterator
assert [str(n) for n in node_depth_first_iter(tree)] == expected

and you can easily make Node.__iter__ use the non-recursive form if you prefer:

def __iter__(self):
    return node_depth_first_iter(self)
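
If a class with a `next()` method that returns `None` at the end is required, as in the question, one possible sketch keeps a stack of child iterators rather than a stack of nodes. This is not the code from the answer above: `PathIterator` and `seen` are hypothetical names, and the sketch assumes that no child is ever `None` and that the tree is not modified during traversal.

```python
class Node(object):
    def __init__(self, title, children=None):
        self.title = title
        self.children = children or []

    def __str__(self):
        return self.title


class PathIterator(object):
    """Pre-order traversal; state is one child iterator per tree level."""

    def __init__(self, root):
        # Wrap the root in a one-element iterator so that it is handled
        # exactly like any other child.
        self.stack = [iter([root])] if root is not None else []

    def next(self):
        while self.stack:
            node = next(self.stack[-1], None)
            if node is None:
                # This level has no children left; go back up one level.
                self.stack.pop()
            else:
                # Remember where to continue among this node's children.
                self.stack.append(iter(node.children))
                return node
        return None


tree = Node('A', [
    Node('B', [Node('C', [Node('D')]), Node('E')]),
    Node('F'),
    Node('G'),
])

it = PathIterator(tree)
seen = []
node = it.next()
while node is not None:
    seen.append(str(node))
    node = it.next()
```

Each list iterator remembers its own position, so the "where was I among the children" state survives every return from `next()`. The stack holds one iterator per level on the path from the root to the current node, so the extra storage grows with the depth of the tree rather than with the number of nodes.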
mgilson

"That could still potentially hold every label, though. I want the iterator to keep only a subset of the tree at a time."

But you already are holding everything. Remember that an object is essentially a dictionary with an entry for each attribute. Having self.visited = False in the __init__ of Node means you are storing a redundant "visited" key and False value for every single Node object no matter what. A set, at least, also has the potential of not holding every single node ID. Try this:

class Iterator(object):
    def __init__(self, root):
        self.visited_ids = set()
        ...

    def next(self):
        ...
        #self.current.visited = True
        self.visited_ids.add(id(self.current))
        ...
                #if not child.visited:
                if id(child) not in self.visited_ids:

Looking up the ID in the set should be just as fast as accessing a node's attribute. The only way this can be more wasteful than your solution is the overhead of the set object itself (not its elements), which is only a concern if you have multiple concurrent iterators (which you obviously don't, otherwise the node visited attribute couldn't be useful to you).
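
One way the elided parts could be filled in is sketched below. This is an assumption about the intended shape, not the answerer's actual code: `VisitedSetIterator`, the `started` flag, and `result` are names introduced here, and the stack handling is simplified relative to the question's `Iterator`.

```python
class Node(object):
    def __init__(self, title, children=None):
        self.title = title
        self.children = children or []  # note: no `visited` attribute

    def __str__(self):
        return self.title


class VisitedSetIterator(object):
    def __init__(self, root):
        self.visited_ids = set()
        self.stack = [root] if root is not None else []
        self.started = False

    def next(self):
        if not self.stack:
            return None
        if not self.started:
            # First call: hand out the root itself.
            self.started = True
            self.visited_ids.add(id(self.stack[0]))
            return self.stack[0]
        while self.stack:
            current = self.stack[-1]
            for child in current.children:
                if id(child) not in self.visited_ids:
                    self.visited_ids.add(id(child))
                    self.stack.append(child)
                    return child
            # Every child of `current` has been handed out; back up.
            self.stack.pop()
        return None


tree = Node('A', [
    Node('B', [Node('C', [Node('D')]), Node('E')]),
    Node('F'),
    Node('G'),
])

it = VisitedSetIterator(tree)
result = []
node = it.next()
while node is not None:
    result.append(str(node))
    node = it.next()
```

In this sketch the set ends up with one id per node handed out so far, and `id()` values are only guaranteed unique among objects that are alive at the same time, which holds here because the tree keeps its nodes alive.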

nmclean