I'm trying to write an algorithm in Python that gets a list of the unique nodes in a tree that lie on at least one path reaching a certain depth.
- The number of children each node has is unknown until it is traversed
- The children can be accessed via an iterable (e.g. `for child in B.get_children()`)
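To pin down the interface described above, here is a hypothetical stand-in node type. The `Node` class, its constructor and `__repr__` are my own assumptions; only "`get_children()` returns an iterable" comes from the description. Making it a generator mimics "unknown until traversed": nothing about the children is known until the caller iterates.

```python
from typing import Iterator, List, Optional


class Node:
    """Hypothetical node: children are produced lazily by get_children()."""

    def __init__(self, name: str, children: Optional[List["Node"]] = None):
        self.name = name
        self._children = children or []

    def get_children(self) -> Iterator["Node"]:
        # A generator: the caller only learns how many children there are
        # by iterating it.
        yield from self._children

    def __repr__(self) -> str:
        return self.name


b = Node("B", [Node("D")])
for child in b.get_children():
    print(child)  # D
```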
For example, see this tree (asterisks mark the nodes that should be included):
```
        A*
        |
    ---------
    |       |
    B*      C*
    |       |
    |     -----
    |     |   |
    D*    E   F*
  / | \       | \
G*  H*  I*    J*  K*
|
L
```
Let's say I'm trying to reach a depth of 3 (counting A as depth 0). I need a function that yields the sequence [G, H, I, J, K, D, F, B, C, A], in any order.
Note the omission of:
- E (doesn't reach a depth of 3)
- L (exceeds a depth of 3)
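Put another way, a node belongs in the output exactly when it sits at most 3 levels below A and its subtree contains a node at depth exactly 3. A small reference check of that rule on the example tree may help when testing a real implementation. The dict encoding and the `expected_nodes` name are hypothetical, and I'm assuming L hangs under G (any depth-3 parent gives the same answer). It is recursive, so it is only meant for small test trees.

```python
from typing import Dict, List, Set

# The example tree as parent -> children (assumption: L's parent is G).
TREE: Dict[str, List[str]] = {
    "A": ["B", "C"], "B": ["D"], "C": ["E", "F"],
    "D": ["G", "H", "I"], "F": ["J", "K"], "G": ["L"],
}


def expected_nodes(tree: Dict[str, List[str]], root: str, target: int) -> Set[str]:
    """Reference check: keep a node iff its subtree has a node at depth `target`."""
    keep: Set[str] = set()

    def reaches(node: str, depth: int) -> bool:
        if depth == target:
            hit = True  # node itself is at the target depth; don't look deeper
        else:
            hit = False
            for child in tree.get(node, []):
                # Visit every child (no short-circuit) so all qualifying ones are kept
                if reaches(child, depth + 1):
                    hit = True
        if hit:
            keep.add(node)
        return hit

    reaches(root, 0)
    return keep


print(sorted(expected_nodes(TREE, "A", 3)))
# ['A', 'B', 'C', 'D', 'F', 'G', 'H', 'I', 'J', 'K']
```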
I feel there is a way to get this list recursively. Something along the lines of:
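```python
from typing import Iterator, List, Set, TypeVar

T = TypeVar("T")

def iterate_tree(path: List[T], all_nodes: Set[T]) -> Iterator[T]:
    if len(path) == 4:  # 4 nodes on the path = depth 3 reached
        # Yield every node on this path that hasn't been yielded yet
        for node in path:
            if node not in all_nodes:
                all_nodes.add(node)
                yield node
    else:
        for child in path[-1].get_children():
            # Build a fresh list per child; rebinding `path` here would
            # carry earlier siblings into later branches
            yield from iterate_tree(path + [child], all_nodes)

# iterate_tree is a generator, so it has to be consumed to do anything
nodes = list(iterate_tree([A], set()))
```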
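One pitfall worth checking in a loop like this: rebinding `path = path.copy()` inside the `for` loop means each later sibling inherits the nodes appended for the earlier ones, whereas `path + [child]` starts from the parent's path every time. A minimal demonstration on D's children (the function names are hypothetical):

```python
from typing import Dict, List


def paths_rebinding(root: str, children: Dict[str, List[str]]) -> List[List[str]]:
    path = [root]
    out = []
    for child in children[root]:
        path = path.copy()  # copies the path *including* earlier siblings
        path.append(child)
        out.append(path)
    return out


def paths_fresh(root: str, children: Dict[str, List[str]]) -> List[List[str]]:
    path = [root]
    # Each child gets its own new list built from the unchanged parent path
    return [path + [child] for child in children[root]]


kids = {"D": ["G", "H", "I"]}
print(paths_rebinding("D", kids))  # [['D', 'G'], ['D', 'G', 'H'], ['D', 'G', 'H', 'I']]
print(paths_fresh("D", kids))      # [['D', 'G'], ['D', 'H'], ['D', 'I']]
```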
I don't know if `iterate_tree` works as written, but I think I can hack it from there. What I don't like about the (probably incorrect) solution is:
- The recursion: I'm writing this for an unknown depth, and I don't want a stack overflow.
- I really feel like there must be a way to do this without carrying around a `set` of previously yielded nodes.
- I have to make a copy of `path` at each iteration so I don't mess up other branches of the recursion.
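One possible direction, sketched under assumptions rather than offered as the definitive answer: in a tree each node has exactly one path from the root, so an iterative post-order DFS with an explicit stack can decide whether a node qualifies once all its children are finished. That avoids recursion, the "already yielded" set and the path copies. The `nodes_reaching_depth` name and the `get_children` callable parameter are my own additions so the sketch runs against a plain dict; with real nodes you would pass `lambda n: n.get_children()`. The example tree again assumes L hangs under G.

```python
from typing import Callable, Dict, Iterable, Iterator, List, TypeVar

T = TypeVar("T")
_DONE = object()  # sentinel: a frame's child iterator is exhausted


def nodes_reaching_depth(
    root: T, target: int, get_children: Callable[[T], Iterable[T]]
) -> Iterator[T]:
    """Yield each node on at least one root path that reaches depth `target`
    (root = depth 0), using an explicit stack instead of recursion."""

    def make_frame(node: T, depth: int) -> list:
        # Only nodes above the target depth are expanded, so the stack never
        # holds more than target + 1 frames and deeper nodes are never visited.
        kids = iter(get_children(node)) if depth < target else iter(())
        # Frame layout: [node, depth, child iterator, reached-target flag]
        return [node, depth, kids, depth == target]

    stack: List[list] = [make_frame(root, 0)]
    while stack:
        top = stack[-1]
        child = next(top[2], _DONE)
        if child is not _DONE:
            stack.append(make_frame(child, top[1] + 1))
            continue
        # All children handled: pop the node (post-order). In a tree each node
        # is popped exactly once, so no "already yielded" set is needed.
        stack.pop()
        if top[3]:
            yield top[0]
            if stack:
                stack[-1][3] = True  # a qualifying child makes its parent qualify


# The example tree as a dict (assumption: L hangs under G).
TREE: Dict[str, List[str]] = {
    "A": ["B", "C"], "B": ["D"], "C": ["E", "F"],
    "D": ["G", "H", "I"], "F": ["J", "K"], "G": ["L"],
}
print(list(nodes_reaching_depth("A", 3, lambda n: TREE.get(n, []))))
# ['G', 'H', 'I', 'D', 'B', 'J', 'K', 'F', 'C', 'A']
```

Because nodes at the target depth are never expanded, `get_children()` is not called on G through K, so L is never produced, and memory use is bounded by the target depth rather than by the total height of the tree.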
Any suggestions?