
I'm writing a path tracer in C++ and I'd like to try to implement the most resource-intensive code in CUDA or OpenCL (I'm not sure which one to pick).

I've heard that my graphics card's version of CUDA doesn't support recursion, which is something my path tracer utilizes heavily.

As I have it coded both in Python and C++, I'll post some simplified Python code for readability:

def Trace(ray):
  hit = what_object_is_hit(ray)

  if not hit:
    return Color(0, 0, 0)

  newRay = hit.bounceChildRayOffSurface(ray)

  return hit.diffuse * (Trace(newRay) + hit.emittance)

I tried manually unrolling the function, and there is a definite pattern (d is diffuse and e is emittance):

Level 1:  d1 * e1 

Level 2:  d1 * d2 * e2
        + e1

Level 3:  d1 * d2 * d3 * e3
        + d1 * d2 * e2
        + e1

Level 4:  d1 * d2 * d3 * d4 * e4
        + d1 * d2 * d3 * e3
        + d1 * d2 * e2
        + e1

I might be wrong, though...

My question is, how would I go about implementing this code in a while loop?

I was thinking using something of this format:

total = Color(0, 0, 0)
n = 1

while n < 10:   # Maximum recursion depth
  result = magical_function()

  if not result:  break

  total += result
  n += 1

I've never really dealt with the task of unraveling a recursive function before, so any help would be greatly appreciated. Thanks!

Blender

3 Answers


In a recursive function, each time a recursive call occurs, the state of the caller is saved to a stack, then restored when the recursive call is complete. To convert a recursive function to an iterative one, you need to turn the state of the suspended function into an explicit data structure. Of course, you can create your own stack in software, but there are often tricks you can use to make your code more efficient.

This answer works through the transformation steps for this example. You can apply the same methods to other recursive functions.

Tail Recursion Transformation

Let's take a look at your code again:

def Trace(ray):
  # Here was code to look for intersections and compute newRay

  if not hit:
      return Color(0, 0, 0)

  return hit.diffuse * (Trace(newRay) + hit.emittance)

In general, a recursive call has to return to the calling function so the caller can finish what it's doing. In this case, the caller "finishes" by performing an addition and a multiplication. This produces a computation like d1 * (d2 * (d3 * (... + e3) + e2) + e1). We can use the distributive law (multiplication distributes over addition) and the associative laws of multiplication and addition to transform the calculation into [d1 * e1] + [(d1 * d2) * e2] + [((d1 * d2) * d3) * e3] + ... .

Note that the first term in this series refers only to iteration 1, the second only to iterations 1 and 2, and so forth. That tells us that we can compute this series on the fly. Moreover, this series contains the series (d1, d1*d2, d1*d2*d3, ...), which we can also compute on the fly. Putting that back into the code:

def Trace(diffuse, emittance, ray):
  # Here was code to look for intersections and compute newRay

  if not hit: return emittance                            # The complete value has been computed

  new_diffuse = diffuse * hit.diffuse                     # (...) * dN
  new_emittance = emittance + new_diffuse * hit.emittance # (...) + [(d1 * ... * dN) * eN]
  return Trace(new_diffuse, new_emittance, newRay)
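To check that the accumulator version computes the same value as the original recursion, here is a minimal runnable sketch. Purely for illustration, it assumes the intersection code is replaced by a fixed list of hits, each a hypothetical (diffuse, emittance) pair of floats, where running past the end of the list means the ray hits nothing. The names `HITS`, `trace_recursive`, and `trace_accumulating` are made up for this example.

```python
# Hypothetical stand-in for the scene: the i-th bounce hits a surface with
# these (diffuse, emittance) values; running past the end means "no hit".
# The values are exact in binary floating point, so results compare exactly.
HITS = [(0.5, 1.0), (0.75, 2.0), (0.25, 4.0)]


def trace_recursive(depth=0):
    """Original form: d * (Trace(next) + e)."""
    if depth >= len(HITS):          # no hit: black
        return 0.0
    d, e = HITS[depth]
    return d * (trace_recursive(depth + 1) + e)


def trace_accumulating(diffuse, emittance, depth=0):
    """Tail-recursive form: all pending work is carried in the arguments."""
    if depth >= len(HITS):          # no hit: the complete value has been computed
        return emittance
    d, e = HITS[depth]
    new_diffuse = diffuse * d                    # (...) * dN
    new_emittance = emittance + new_diffuse * e  # (...) + [(d1 * ... * dN) * eN]
    return trace_accumulating(new_diffuse, new_emittance, depth + 1)


print(trace_recursive())             # 1.625
print(trace_accumulating(1.0, 0.0))  # 1.625
```

The initial call passes diffuse = 1 (nothing attenuated yet) and emittance = 0 (no light gathered yet), the identities for multiplication and addition.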

Tail Recursion Elimination

In the new version, the caller has no work to do after the callee finishes; it simply returns the callee's result. Since there is nothing left to finish, the caller doesn't have to save any of its state! Instead of a call, we can overwrite the old parameters and jump back to the beginning of the function (not valid Python, but it illustrates the point):

def Trace(diffuse, emittance, ray):
  beginning:
  # Here was code to look for intersections and compute newRay

  if not hit: return emittance                            # The complete value has been computed

  new_diffuse = diffuse * hit.diffuse                     # (...) * dN
  new_emittance = emittance + new_diffuse * hit.emittance # (...) + [(d1 * ... * dN) * eN]
  (diffuse, emittance, ray) = (new_diffuse, new_emittance, newRay)
  goto beginning

Finally, we have transformed the recursive function into an equivalent loop. All that's left is to express it in Python syntax.

def Trace(diffuse, emittance, ray):
  # Start with diffuse = 1 (nothing attenuated yet) and emittance = 0 (no light gathered yet)
  while True:
    # Here was code to look for intersections and compute newRay

    if not hit: break

    diffuse = diffuse * hit.diffuse                 # (...) * dN
    emittance = emittance + diffuse * hit.emittance # (...) + [(d1 * ... * dN) * eN]
    ray = newRay

  return emittance
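Here is a minimal runnable sketch of the loop version. It assumes, purely for illustration, that the intersection code is replaced by a fixed list of hits, each a hypothetical (diffuse, emittance) pair of floats, with the end of the list meaning "no hit". The `max_depth` parameter is an added assumption that mirrors the maximum depth of 10 in the question; `trace_loop` and `HITS` are made-up names.

```python
# Hypothetical stand-in for the scene: (diffuse, emittance) per bounce;
# running past the end of the list means the ray hits nothing.
HITS = [(0.5, 1.0), (0.75, 2.0), (0.25, 4.0)]


def trace_loop(hits, max_depth=10):
    """Iterative Trace: the former arguments become ordinary loop variables."""
    diffuse = 1.0    # product d1 * ... * dN so far (nothing attenuated yet)
    emittance = 0.0  # running sum of (d1 * ... * dN) * eN (no light yet)
    for depth in range(max_depth):
        if depth >= len(hits):       # no hit: stop bouncing
            break
        d, e = hits[depth]
        diffuse = diffuse * d                  # (...) * dN
        emittance = emittance + diffuse * e    # (...) + [(d1 * ... * dN) * eN]
    return emittance


print(trace_loop(HITS))               # 1.625
print(trace_loop(HITS, max_depth=1))  # 0.5, only the first bounce counts
```

Because the loop only ever holds two running values, its memory use is constant regardless of depth, which is what makes it attractive where recursion isn't available.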
Heatsink
  • perhaps a decorator which encapsulates this pattern may also be useful http://code.activestate.com/recipes/496691-new-tail-recursion-decorator/ – ninjagecko Jun 10 '11 at 02:04

You're in luck. Your code uses tail recursion, which is when the recursive call is the last thing your function does. A compiler can often turn this into a loop for you, but here you'll have to do it manually:

total = Color(0, 0, 0)
mult = 1
n = 1

while n < 10:   # Maximum recursion depth
  # Here was code to look for intersections and compute newRay

  if not hit: break

  total += mult * hit.diffuse * hit.emittance
  mult *= hit.diffuse
  ray = newRay
  n += 1

return total
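The loop above works channel by channel on colors, too. A minimal runnable sketch, assuming a hypothetical `Color` class whose `*` and `+` act per channel (the question never shows its `Color` type) and replacing the intersection code with a hypothetical fixed list of hits, each a (diffuse, emittance) pair of Colors:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Color:
    """Hypothetical RGB color; * and + act channel by channel."""
    r: float
    g: float
    b: float

    def __add__(self, other):
        return Color(self.r + other.r, self.g + other.g, self.b + other.b)

    def __mul__(self, other):
        return Color(self.r * other.r, self.g * other.g, self.b * other.b)


def trace(hits, max_depth=10):
    total = Color(0, 0, 0)  # light gathered so far
    mult = Color(1, 1, 1)   # product of the diffuse colors of earlier bounces
    n = 0
    # Running out of hits plays the role of "if not hit: break".
    while n < max_depth and n < len(hits):
        diffuse, emittance = hits[n]
        total += mult * diffuse * emittance  # falls back to __add__, rebinding total
        mult *= diffuse                      # falls back to __mul__, rebinding mult
        n += 1
    return total


# A red-tinted surface that emits nothing, followed by a white light source.
hits = [(Color(0.5, 0.25, 0.25), Color(0, 0, 0)),
        (Color(1, 1, 1), Color(4, 4, 4))]
print(trace(hits))  # Color(r=2.0, g=1.0, b=1.0)
```

The light source's contribution arrives tinted by the first surface's diffuse color, because `mult` already holds that color when the second bounce is processed.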
r2jitu
  • The original code does not use tail recursion. It performs an add and a multiply after the recursive call. You can make it tail recursive, but that involves some nontrivial assumptions that a compiler probably won't make. Since the question was about how to eliminate recursion, I think it's important not to sweep those assumptions under the rug. – Heatsink Jun 10 '11 at 01:18
  • Okay, I agree that it's not really tail recursion, but this treats it as such and it should still work. – r2jitu Jun 10 '11 at 01:22

In general, you can always represent recursion with an explicit stack.

For example:

stack = []
# Descend: push each suspended caller's state (its hit) instead of recursing
for level in range(10):   # maximum recursion depth
   # compute hit, and newray from ray
   if not hit: break
   stack.append(hit)
   ray = newray
# Unwind: finish the suspended calls, innermost first
color = Color(0, 0, 0)
while stack:
   hit = stack.pop()
   color = hit.diffuse * (color + hit.emittance)
return color

Essentially, recursion works by saving the caller's pending state on the call stack and calling the function again with new arguments; when the call returns, the caller restores that state and finishes its work. You just have to emulate this using an explicit stack.
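Here is a minimal runnable sketch of that emulation. It first "calls" down the bounces, pushing each caller's pending state, then pops the stack innermost-first, the same order in which the real call stack would return. Purely for illustration, the scene is replaced by a hypothetical fixed list of (diffuse, emittance) pairs; `HITS`, `trace_recursive`, and `trace_with_stack` are made-up names.

```python
# Hypothetical stand-in for the scene: (diffuse, emittance) per bounce;
# running past the end of the list means the ray hits nothing.
HITS = [(0.5, 1.0), (0.75, 2.0), (0.25, 4.0)]


def trace_recursive(depth=0):
    """Reference: the original recursion, d * (Trace(next) + e)."""
    if depth >= len(HITS):
        return 0.0
    d, e = HITS[depth]
    return d * (trace_recursive(depth + 1) + e)


def trace_with_stack(max_depth=10):
    """Emulates the call stack explicitly: descend, then unwind."""
    stack = []
    # Descend: each iteration is one "call"; save the state the caller
    # still needs after the callee returns (here, just its hit).
    for depth in range(max_depth):
        if depth >= len(HITS):        # no hit: the deepest call returns black
            break
        stack.append(HITS[depth])
    # Unwind: pop the suspended callers innermost-first and finish their work.
    color = 0.0
    while stack:
        d, e = stack.pop()
        color = d * (color + e)
    return color


print(trace_with_stack())  # 1.625
```

Unlike the tail-recursive rewrite, this needs no algebraic rearrangement, at the cost of memory proportional to the depth, so it also works for recursions whose pending work can't be reordered.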

Himadri Choudhury
  • Hmm, I'm a bit confused. `stack` contains both a `hit` (from `.pop()`), *and* the color value? `hit` is an object which randomly differs each time the `while` loop iterates (the `ray` reflects off of it and might hit another object, which then becomes `hit`). – Blender Jun 10 '11 at 00:33
  • Ok. I was confused by the different ray and color parameters. I updated the answer. – Himadri Choudhury Jun 10 '11 at 00:48