I'm trying to write an algorithm in Python to get a unique list of all nodes in a tree where the path reaches a certain depth.

  • Each node has an unknown number of children prior to traversal
  • The children can be accessed via an iterable (e.g. for child in B.get_children())

For example, see this tree (asterisks mark nodes that should be included):

       A*
       |
     -----
    |     |
    B*    C*
    |     |
    |    ---
    |   |   |
    D*  E   F*
  / | \     | \
 G* H* I*   J* K*
               |
               L

Let's say I'm trying to reach a depth of 3. I need a function that would yield the sequence [G, H, I, J, K, D, F, B, C, A] in any order.

Note the omission of:

  • E (doesn't reach depth of 3)
  • L (exceeds depth of 3)

I feel there is a way to get this list recursively. Something along the lines of:

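from typing import List, Set, TypeVar

T = TypeVar("T")

def iterate_tree(path: List[T], all_nodes: Set[T]):
    if len(path) == 4:  # root plus 3 more levels, i.e. depth 3
        for node in path:
            if node not in all_nodes:
                all_nodes.add(node)
                yield node
    else:
        for node in path[-1].get_children():
            # Copy into a new name: rebinding `path` here would leak
            # earlier siblings into later iterations of this loop.
            child_path = path.copy()
            child_path.append(node)
            yield from iterate_tree(child_path, all_nodes)

iterate_tree([A], set())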

I don't know if the above works, but I think I can hack it from that. What I don't like about the (probably incorrect) solution is:

  1. The recursion: I'm writing this for an unknown depth, and I don't want a stack overflow.
  2. I really feel like there must be a way to do this without carrying around a set of previously yielded nodes.
  3. I have to make a copy of path at each iteration so I don't mess up other branches of the recursion.

Any suggestions?

user276833

1 Answer

For point 1, you can use an explicit stack and loop instead of recursion.
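
As a rough sketch of that conversion (not the full solution), the loop below finds the nodes at a given depth with an explicit stack of (node, depth) pairs. The Node class and the nodes_at_depth name are made up for illustration; only get_children() comes from the question.

```python
class Node:
    """Hypothetical node type; only get_children() is taken from the question."""

    def __init__(self, val, children=()):
        self.val = val
        self._children = list(children)

    def get_children(self):
        return iter(self._children)


def nodes_at_depth(root, target_depth):
    """Yield every node exactly target_depth edges below root, without recursion."""
    stack = [(root, 0)]  # each entry pairs a node with its own depth
    while stack:
        node, depth = stack.pop()
        if depth == target_depth:
            yield node  # do not descend any further
        else:
            for child in node.get_children():
                stack.append((child, depth + 1))


tree = Node("A", [Node("B", [Node("D")]), Node("C", [Node("E"), Node("F")])])
print(sorted(n.val for n in nodes_at_depth(tree, 2)))  # => ['D', 'E', 'F']
```

Because the pending work lives in a list on the heap, the traversal depth is limited by memory rather than by Python's recursion limit.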

For point 2, I'm not sure I see a problem with keeping a set of yielded nodes. Memory is cheap and if you need to detect duplicates, re-traversing the tree every time you yield is extremely expensive.

Furthermore, your implementation checks for uniqueness based on node hashability, but it's unclear how nodes compute their hash. I assume you're using Node.val for that. If you're hashing based on object reference, "uniqueness" seems pointless since you're guaranteed that a tree of Node objects is unique by identity. The example here doesn't show what a clash on uniqueness would entail. My implementation assumes the hash is object identity (as it should be) and that you can access the value for uniqueness separately using Node.val.
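
To make the identity-versus-value distinction concrete, here is a small sketch. PlainNode and TupleNode are hypothetical stand-ins for whatever your node type is. A class that defines neither __eq__ nor __hash__ is hashed by identity, while a namedtuple is hashed by its contents, so the namedtuple in the demo further down only behaves like identity hashing because every value in that tree is distinct.

```python
import collections


class PlainNode:
    """Hypothetical node class that defines neither __eq__ nor __hash__."""

    def __init__(self, val):
        self.val = val


TupleNode = collections.namedtuple("TupleNode", "val children")

# Plain objects compare and hash by identity: equal values stay distinct.
plain = {PlainNode("X"), PlainNode("X")}
print(len(plain))  # => 2

# namedtuples compare and hash by content: equal values collapse into one.
tuples = {TupleNode("X", ()), TupleNode("X", ())}
print(len(tuples))  # => 1

# Deduplicating by value works for either kind if you track values yourself.
seen_vals = set()
unique = []
for node in (PlainNode("X"), PlainNode("X"), PlainNode("Y")):
    if node.val not in seen_vals:
        seen_vals.add(node.val)
        unique.append(node)
print([n.val for n in unique])  # => ['X', 'Y']
```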

For point 3, if you're working recursively, there's no need to copy the path list: each call returns to its caller's frame, so you can append to and pop from a single shared list. Iteratively, you can keep a parent_of dict alongside the nodes_yielded set that maps each node to its parent. When we reach a node at the desired depth, we can walk the links in this dictionary to reconstruct the path, and nodes_yielded lets us stop as soon as we hit a branch that has already been walked. A second set, vals_yielded, can be used to enforce uniqueness on the yields.
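
The recursive variant of point 3 might look something like the sketch below, which appends to and pops from one shared list instead of copying it. The Node class is a hypothetical stand-in, and uniqueness here is by object identity. It is still recursive, so it does nothing for point 1.

```python
class Node:
    """Hypothetical node type; only get_children() comes from the question."""

    def __init__(self, val, children=()):
        self.val = val
        self._children = list(children)

    def get_children(self):
        return iter(self._children)


def iterate_tree(node, depth, path=None, seen=None):
    """Yield each node once if it lies on a root path that reaches `depth`."""
    if path is None:
        path, seen = [], set()
    path.append(node)  # extend the shared path in place
    if depth == 0:
        for ancestor in path:
            if ancestor not in seen:
                seen.add(ancestor)
                yield ancestor
    else:
        for child in node.get_children():
            yield from iterate_tree(child, depth - 1, path, seen)
    path.pop()  # backtrack so sibling branches see the original path


root = Node("A", [
    Node("B", [Node("D", [Node("G"), Node("H"), Node("I")])]),
    Node("C", [Node("E"), Node("F", [Node("J"), Node("K", [Node("L")])])]),
])
print(sorted(n.val for n in iterate_tree(root, 3)))
# => ['A', 'B', 'C', 'D', 'F', 'G', 'H', 'I', 'J', 'K']
```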

Lastly, I don't really know what your data structures are, so in the interest of a minimal, complete example, I've provided something that should be adaptable for you.

import collections

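def unique_nodes_on_paths_to_depth(root, depth):
    parent_of = {root: None}  # child -> parent links for rebuilding paths
    nodes_yielded = set()     # nodes already walked on a qualifying path
    vals_yielded = set()      # values already yielded to the caller
    stack = [(root, depth)]   # (node, levels still to descend)

    while stack:
        # Use a separate name so the `depth` argument isn't shadowed.
        node, remaining = stack.pop()

        if remaining == 0:
            # Target depth reached: walk up via parent_of until we hit a
            # node an earlier path already covered, or run past the root.
            while node is not None and node not in nodes_yielded:
                if node.val not in vals_yielded:
                    vals_yielded.add(node.val)
                    yield node

                nodes_yielded.add(node)
                node = parent_of[node]
        elif remaining > 0:
            for child in node.children:
                parent_of[child] = node
                stack.append((child, remaining - 1))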

if __name__ == "__main__":
    """
           A*
           |
         -----
        |     |
        B*    C*
        |     |
        |    ---
        |   |   |
        D*  E   F*
      / | \     | \
     G* H* I*   J* K*
                   |
                   L
    """
    Node = collections.namedtuple("Node", "val children")
    root = Node("A", (
        Node("B", (
            Node("D", (
                Node("G", ()),
                Node("H", ()),
                Node("I", ()),
            ))
        ,)),
        Node("C", (
            Node("E", ()),
            Node("F", (
                Node("J", ()),
                Node("K", (
                    Node("L", ())
                ,)),
            )),
        ))
    ))
    print([x.val for x in unique_nodes_on_paths_to_depth(root, 3)])
    #      => ['K', 'F', 'C', 'A', 'J', 'I', 'D', 'B', 'H', 'G'] 
ggorlen