In Python it would look like this:
```python
def countNodes(root):
    if root is None:  # Empty tree: nothing to count
        return
    root.numNodes = 1
    stack = [[root, 0]]  # A stack entry consists of a node paired with a stage
    while stack:
        stack[-1][1] += 1  # Increment stage of stack top (0->1->2->3)
        node, stage = stack[-1]  # Peek stack top
        if stage == 1 and node.left:
            node.left.numNodes = 1  # The child counts itself before its subtree is visited
            stack.append([node.left, 0])
        elif stage == 2 and node.right:
            node.right.numNodes = 1
            stack.append([node.right, 0])
        elif stage == 3:
            stack.pop()
            if stack:
                # Add the finished subtree's count to its parent (the new stack top)
                stack[-1][0].numNodes += node.numNodes
```
This gives every node a `numNodes` attribute that ends up holding the number of nodes in its subtree (including the node itself).
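A minimal usage sketch, assuming a hypothetical `Node` class with `left` and `right` attributes (the answer does not define one; any object with those attributes works). The function is repeated so the snippet runs on its own, and it runs on a four-node tree:

```python
class Node:
    """Hypothetical minimal binary-tree node, not part of the original answer."""
    def __init__(self, left=None, right=None):
        self.left = left
        self.right = right


def countNodes(root):
    if root is None:  # Empty tree: nothing to count
        return
    root.numNodes = 1
    stack = [[root, 0]]  # Each entry: [node, stage]
    while stack:
        stack[-1][1] += 1  # Advance the stage of the stack top
        node, stage = stack[-1]
        if stage == 1 and node.left:
            node.left.numNodes = 1
            stack.append([node.left, 0])
        elif stage == 2 and node.right:
            node.right.numNodes = 1
            stack.append([node.right, 0])
        elif stage == 3:
            stack.pop()
            if stack:
                stack[-1][0].numNodes += node.numNodes  # Report to parent


#        a
#       / \
#      b   c
#     /
#    d
d = Node()
b = Node(left=d)
c = Node()
a = Node(left=b, right=c)

countNodes(a)
print(a.numNodes, b.numNodes, c.numNodes, d.numNodes)  # 4 2 1 1
```

Each node's count is complete when it reaches stage 3, and at that moment it is added to its parent, which is the entry just below it on the stack.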
The idea is that each node is pushed onto the stack together with a "stage", which records how far the processing of that node has got:
- Stage 0: nothing yet, first time encounter
- Stage 1: visiting its left subtree
- Stage 2: visiting its right subtree
- Stage 3: both subtrees are done; the node is popped from the stack and its count is added to its parent's count.
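For comparison, here is a sketch of the recursive version that the explicit stack replaces; it is a different (recursive) technique, shown only to map each stage to a point in an ordinary recursive call. The `Node` class and the name `countNodesRecursive` are hypothetical:

```python
class Node:
    """Hypothetical minimal binary-tree node, not part of the original answer."""
    def __init__(self, left=None, right=None):
        self.left = left
        self.right = right


def countNodesRecursive(node):
    """Set node.numNodes for every node in the subtree and return the root's count."""
    if node is None:
        return 0
    node.numNodes = 1                                 # Stage 0: first encounter, count itself
    node.numNodes += countNodesRecursive(node.left)   # Stage 1: left subtree
    node.numNodes += countNodesRecursive(node.right)  # Stage 2: right subtree
    return node.numNodes                              # Stage 3: done, the caller adds this


root = Node(left=Node(left=Node()), right=Node())
print(countNodesRecursive(root))  # 4
```

The stage number in the iterative version plays the role of the "return address" that the call stack keeps for the recursive one. Because CPython limits recursion depth (see `sys.getrecursionlimit()`), the recursive form can raise `RecursionError` on very deep, degenerate trees, while the explicit stack avoids that limit.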