I'm trying the "defunctionalize the continuation" technique on some recursive functions, to see if I can get a good iterative version to pop out. Following along with The Best Refactoring You've Never Heard Of (in Lua just for convenience; this seems to be mostly language-agnostic), I did this:

-- original
function printTree(tree)
    if tree then
        printTree(tree.left)
        print(tree.content)
        printTree(tree.right)
    end
end
tree = { left = { content = 1 }, content = 2, right = { content = 3 } }
printTree(tree)

-- make tail-recursive with CPS
function printTree(tree, kont)
    if tree then
        printTree(tree.left, function()
            print(tree.content)
            printTree(tree.right, kont)
        end)
    else
        kont()
    end
end
tree = { left = { content = 1 }, content = 2, right = { content = 3 } }
printTree(tree, function() end)

-- defunctionalize the continuation
function apply(kont)
    if kont then
        print(kont.tree.content)
        printTree(kont.tree.right, kont.next)
    end
end
function printTree(tree, kont)
    if tree then
        printTree(tree.left, { tree = tree, next = kont })
    else
        apply(kont)
    end
end
tree = { left = { content = 1 }, content = 2, right = { content = 3 } }
printTree(tree)

-- inline apply
function printTree(tree, kont)
    if tree then
        printTree(tree.left, { tree = tree, next = kont })
    elseif kont then
        print(kont.tree.content)
        printTree(kont.tree.right, kont.next)
    end
end
tree = { left = { content = 1 }, content = 2, right = { content = 3 } }
printTree(tree)

-- perform tail-call elimination
function printTree(tree, kont)
    while true do
        if tree then
            kont = { tree = tree, next = kont }
            tree = tree.left
        elseif kont then
            print(kont.tree.content)
            tree = kont.tree.right
            kont = kont.next
        else
            return
        end
    end
end
tree = { left = { content = 1 }, content = 2, right = { content = 3 } }
printTree(tree)

Then I tried the same technique on the factorial function:

-- original
function factorial(n)
    if n == 0 then
        return 1
    else
        return n * factorial(n - 1)
    end
end
print(factorial(6))

-- make tail-recursive with CPS
function factorial(n, kont)
    if n == 0 then
        return kont(1)
    else
        return factorial(n - 1, function(x)
            return kont(n * x)
        end)
    end
end
print(factorial(6, function(x) return x end))

-- defunctionalize the continuation
function apply(kont, x)
    if kont then
        return apply(kont.next, kont.n * x)
    else
        return x
    end
end
function factorial(n, kont)
    if n == 0 then
        return apply(kont, 1)
    else
        return factorial(n - 1, { n = n, next = kont })
    end
end
print(factorial(6))

Here's where things start to go wrong. The next step is to inline apply, but I can't do that since apply calls itself recursively. To keep going, I tried doing tail-call elimination on it.

-- perform tail-call elimination
function apply(kont, x)
    while kont do
        x = kont.n * x
        kont = kont.next
    end
    return x
end
function factorial(n, kont)
    if n == 0 then
        return apply(kont, 1)
    else
        return factorial(n - 1, { n = n, next = kont })
    end
end
print(factorial(6))

Okay, now we seem to be back on track.

-- inline apply
function factorial(n, kont)
    if n == 0 then
        local x = 1
        while kont do
            x = kont.n * x
            kont = kont.next
        end
        return x
    else
        return factorial(n - 1, { n = n, next = kont })
    end
end
print(factorial(6))

-- perform tail-call elimination
function factorial(n, kont)
    while n ~= 0 do
        kont = { n = n, next = kont }
        n = n - 1
    end
    local x = 1
    while kont do
        x = kont.n * x
        kont = kont.next
    end
    return x
end
print(factorial(6))

Okay, we got a fully iterative implementation of factorial, but it's a pretty bad one. I was hoping to end up with something like this instead:

function factorial(n)
    local x = 1
    while n ~= 0 do
        x = n * x
        n = n - 1
    end
    return x
end
print(factorial(6))

Is there any modification to the steps I followed that will let me mechanically end up with a function that looks more like this one?

  • This looks insanely complicated. How about `function factorial(n, acc) if n == 0 then acc else factorial(n-1, n*acc)`? – Stef Oct 20 '21 at 17:57
  • @Stef I know the snippet in your comment is a "good" recursive form of the factorial function. What I'm asking is for a way that I could mechanically derive something like that. – Joseph Sible-Reinstate Monica Oct 20 '21 at 17:59
  • I tried reading the blogpost you linked, but I got lost after two paragraphs. It doesn't help that it relies on screenshots of code, which are all blurry on my screen. Sorry. I do know more systematic ways to transform a recursive function into a tail-recursive function or into an iterative function, but they don't involve all this "kont" stuff and I don't know what *that*'s about. – Stef Oct 20 '21 at 18:02
  • You might be interested in Appel's Compiling with Continuations. – A. Webb Nov 15 '21 at 14:59

1 Answer

First, congrats on having followed all the steps correctly. I've watched many people attempt defunctionalization exercises, and I've seen this rarely.

The 3-step process of CPS, defunctionalization, and tail-call elimination is not a standalone recipe that can produce clean code in all contexts without modification. You saw this when you needed to TCE the apply function itself. (Which, by the way, is also necessary if you want to use parentheses when printing trees, making that version much more complicated than the example I used in the blog post.)

In this case, there is an extra step needed to get the clean factorial function you desire. The function you wrote computes 5! as 5*(4*(3*(2*1))), and so the final continuation is function(x) return 5*(4*(3*(2*x))) end. This involves 4 multiplications, and, more generally, a continuation of this form involves an unbounded number of multiplications. This is not possible to handle without a second loop....

....UNLESS you note that that continuation is equivalent to function(x) return 120*x end. The key property is the associativity of multiplication. Indeed, compare: let's say we defined a factorial analogue with a non-associative operator like exponentiation instead of multiplication. Call it powerorial, such that powerorial(5)=5^(4^(3^(2^1))). Then it would not be possible to simplify away the loop.
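To make the powerorial point concrete, here is a small sketch. The function names are hypothetical and only for illustration. It compares the direct recursive definition with a single accumulator loop shaped like the hoped-for factorial loop. Because exponentiation is not associative, the loop groups the operations differently and computes something else.

```lua
-- Direct recursive definition: powerorial(n) = n ^ powerorial(n - 1).
-- For n = 4 this is 4^(3^(2^1)) = 4^9 = 262144.
local function powerorial(n)
    if n == 0 then
        return 1
    else
        return n ^ powerorial(n - 1)
    end
end

-- Naive accumulator loop, copied from the shape of the clean factorial loop.
-- It applies the operator starting from the largest n, which regroups the
-- operations; that regrouping is only valid for an associative operator.
local function powerorialNaive(n)
    local x = 1
    while n ~= 0 do
        x = n ^ x
        n = n - 1
    end
    return x
end

assert(powerorial(4) == 262144)
assert(powerorialNaive(4) ~= powerorial(4))
```

The naive loop finishes with `1 ^ x`, so it collapses to 1 for every positive n, which is clearly not the intended value.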

Here, you need to show that apply(kont, x) is equivalent to acc*x where acc is equivalent to the partial factorial contained in kont. E.g.: for the defunctionalized continuation kont={n=3, next={n=4, next={n=5, next=nil}}}, show that apply(kont, x) is equivalent to 60*x. This is a simple inductive proof using associativity.
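As a rough sketch of where that proof leads (not a replacement for it): the helper `collapse` and the names `factorialAcc` and `factorialLoop` below are hypothetical and do not appear in the original code. `collapse` computes the number `acc` that a defunctionalized continuation stands for. Once `apply(kont, x)` is known to equal `collapse(kont) * x`, the continuation can be replaced by that number everywhere, and tail-call elimination then gives a single loop.

```lua
-- Loop-based apply from the question: walks the defunctionalized continuation.
local function apply(kont, x)
    while kont do
        x = kont.n * x
        kont = kont.next
    end
    return x
end

-- Hypothetical helper: the single number acc with apply(kont, x) == acc * x.
local function collapse(kont)
    local acc = 1
    while kont do
        acc = acc * kont.n
        kont = kont.next
    end
    return acc
end

-- The example continuation from the text: 5 * (4 * (3 * x)) == 60 * x.
local kont = { n = 3, next = { n = 4, next = { n = 5, next = nil } } }
assert(collapse(kont) == 60)
assert(apply(kont, 7) == 60 * 7)

-- Replace kont by acc = collapse(kont) throughout the defunctionalized version:
--   nil                     becomes 1
--   { n = n, next = kont }  becomes acc * n   (this step uses associativity)
--   apply(kont, 1)          becomes acc * 1, i.e. acc
local function factorialAcc(n, acc)
    acc = acc or 1
    if n == 0 then
        return acc
    else
        return factorialAcc(n - 1, acc * n)
    end
end

-- Tail-call elimination on factorialAcc leaves one loop and no list.
local function factorialLoop(n)
    local acc = 1
    while n ~= 0 do
        acc = acc * n
        n = n - 1
    end
    return acc
end

assert(factorialAcc(6) == 720)
assert(factorialLoop(6) == 720)
```

The recursive `factorialAcc` has the same shape as the accumulator version suggested in the comments on the question, reached here by a mechanical substitution rather than by insight.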

The insight that associativity in its various forms is a key step to getting clean iterative functions with accumulators applies to many problems, enough that Jeremy Gibbons wrote a paper on this topic called Continuation-Passing Style, Defunctionalization, Accumulations, and Associativity — with factorial as his first example!

This is all part of the larger area of mechanical manipulation of programs, a topic practiced more deeply in our Advanced Software Design Course. If you want to go deeper than that, you can also start reading papers on the topic of program manipulation and equational reasoning, including the works of Jeremy Gibbons, Richard Bird, Zohar Manna, and Bill Scherlis.

James Koppel