
I am aware that Scala optimizes tail-recursive functions (i.e. functions in which the recursive call is the last thing the function executes). What I am asking is whether there is a way to optimize tail calls to different functions. Consider the following Scala code:

```scala
def doA(): Unit = {
  doB()
}

def doB(): Unit = {
  doA()
}
```

Because the two functions call each other forever, running this eventually throws a StackOverflowError. Allocating more stack space (for example with the JVM's `-Xss` option) only postpones the problem: the larger stack is eventually exhausted as well. One way to mitigate this could be:

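```scala
// Each step returns the next step wrapped in a thunk instead of calling it.
case class C(f: () => C)

// Trampoline: the loop invokes one thunk at a time, so the stack depth stays constant.
def run(): Unit = {
  var c: C = C(() => doA())
  while (true) {
    c = c.f()
  }
}

def doA(): C = {
  C(() => doB())
}

def doB(): C = {
  C(() => doA())
}
```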

However, this proved to be quite slow. Is there a better way to optimize this?
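To make the distinction drawn above concrete, here is a small sketch (the names `TailCallDemo`, `countDown`, `isEven`, `isOdd` and `overflows` are hypothetical, chosen for illustration). scalac rewrites the self-recursive `countDown` into a loop, while the mutually recursive `isEven`/`isOdd` pair, although every call is in tail position, still consumes one JVM stack frame per call.

```scala
import scala.annotation.tailrec

object TailCallDemo {
  // Self-recursive tail call: scalac compiles this into a loop,
  // so it runs in constant stack space regardless of n.
  @tailrec
  def countDown(n: Long): Long =
    if (n == 0) 0L else countDown(n - 1)

  // Mutual recursion: each call is in tail position, but it targets a
  // different method, so scalac cannot turn it into a loop and every
  // call pushes a new stack frame.
  def isEven(n: Long): Boolean = if (n == 0) true else isOdd(n - 1)
  def isOdd(n: Long): Boolean = if (n == 0) false else isEven(n - 1)

  // Returns true if evaluating `body` throws a StackOverflowError.
  def overflows(body: => Any): Boolean =
    try { body; false }
    catch { case _: StackOverflowError => true }
}
```

With a default-sized stack, `TailCallDemo.countDown(100000000L)` returns normally, while `TailCallDemo.overflows(TailCallDemo.isEven(100000000L))` should return `true`; the exact depth at which the overflow occurs depends on the thread's stack size.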

griz
  • Have you had a look at [TailCalls](https://www.scala-lang.org/files/archive/api/current/scala/util/control/TailCalls$.html) in the standard library? (Also mentioned in [this recent answer](https://stackoverflow.com/a/65863771/4993128).) – jwvh Jan 24 '21 at 08:10
  • Not really, no. I'm guessing it should be more efficient than what I suggested, right? Are there any limitations of this that I should be aware of? – griz Jan 24 '21 at 08:21
  • According to [the accompanying paper](http://blog.higher-order.com/assets/trampolines.pdf), _"...the well-known technique of trampolining...allows us to programmatically exchange stack for heap."_ So while you shouldn't get a StackOverflowError, heap allocation still has a finite limit. – jwvh Jan 24 '21 at 09:02
  • Thanks for the reference. I tried running some benchmarks using TailCalls and it was still pretty slow. I guess anything that involves trampolining will have a considerable amount of overhead, since you are using the heap rather than the stack, as explained [here](https://gist.github.com/eamelink/4466932a11d8d92a6b76e80364062250). – griz Jan 24 '21 at 10:41
  • I doubly. All ways to implement continuations provided by libraries (and a lot of ways provided by languages' compilers) rely on allocating instances of some intermediate type which is used to suspend computation. I guess rewriting your code manually to use `while` or `@tailrec` is the only "fast" way. – Mateusz Kubuszok Jan 24 '21 at 12:15
  • What's the purpose? Is this cross-recursion a requirement? There are many different ways to create an infinite alternation between 2 (or more) functions. – jwvh Jan 24 '21 at 13:01
  • The problem is that the functions are not always called in the same order. The order is only decided at runtime. So let's assume we had another function doC() that at runtime can either call doA() or doB() and that doB() only calls doC(). So, an example trace could be: doA() -> doB() -> doC() -> doB() -> doC(). This makes it impossible to have a while loop because we wouldn't know where the point of execution should continue from. I can explain further in the main question if it is not clear. – griz Jan 24 '21 at 14:50
  • You can, but you have to use either an ADT and pattern matching (= allocation) or flags, `var`s with `null`s and/or casting to pick the case and arguments. – Mateusz Kubuszok Jan 24 '21 at 20:10
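As a sketch of the `TailCalls` suggestion from the comments, applied to the `doA()`/`doB()`/`doC()` scenario the asker describes (where `doC()` picks its successor at run time): each function returns a `TailRec` describing its next call, and `.result` runs those steps in a loop on the heap. The object name, the `remaining` budget and the `steps` counter are hypothetical additions so that the example terminates.

```scala
import scala.util.Random
import scala.util.control.TailCalls._

object TailCallsDemo {
  // `remaining` bounds the trace length; `steps` counts the calls made.
  // `tailcall` suspends the next call instead of making it on the stack.
  def doA(remaining: Long, steps: Long): TailRec[Long] =
    if (remaining == 0) done(steps)
    else tailcall(doB(remaining - 1, steps + 1))

  def doB(remaining: Long, steps: Long): TailRec[Long] =
    if (remaining == 0) done(steps)
    else tailcall(doC(remaining - 1, steps + 1))

  // The successor of doC is only decided at run time.
  def doC(remaining: Long, steps: Long): TailRec[Long] =
    if (remaining == 0) done(steps)
    else if (Random.nextBoolean()) tailcall(doA(remaining - 1, steps + 1))
    else tailcall(doB(remaining - 1, steps + 1))
}
```

`TailCallsDemo.doA(1000000L, 0L).result` completes without a StackOverflowError. Each `tailcall` still allocates a small object per step, which is consistent with the overhead reported in the comments: the technique trades stack for heap rather than removing the cost.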

1 Answer


Here's one way to achieve an infinite progression of method calls, without consuming stack, where each method decides which method goes next.

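```scala
// Each method returns the next method (eta-expanded into a function value)
// instead of calling it, so no stack frames accumulate.
def doA(): () => Any = {
  doB _
}
def doB(): () => Any = {
  doC _
}
def doC(): () => Any = {
  // The next step is chosen at run time.
  if (util.Random.nextBoolean()) doA _
  else                           doB _
}

// Driver loop: invoke the current step to obtain the next one, forever.
// The cast recovers the `() => () => Any` shape that the `Any` result type hides.
Iterator.iterate(doA())(_.asInstanceOf[() => () => Any]())
        .foreach(identity)
```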
jwvh
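For comparison, here is a sketch of the allocation-free "flags" approach mentioned in the comments: each function returns a tag naming the function that should run next, and a single `while` loop dispatches on that tag. The object name `StateMachineDemo`, the integer tags `A`/`B`/`C` and the `steps` limit are hypothetical. Functions that take arguments would additionally need those arguments stored in `var`s between iterations.

```scala
import scala.util.Random

object StateMachineDemo {
  // Hypothetical tags standing in for "which function runs next".
  final val A = 0
  final val B = 1
  final val C = 2

  // Each body does its work and returns the tag of its successor
  // instead of calling it, so no closures or wrappers are allocated.
  private def doA(): Int = B
  private def doB(): Int = C
  private def doC(): Int = if (Random.nextBoolean()) A else B

  // Runs `steps` transitions starting from doA and returns the final tag.
  def run(steps: Long): Int = {
    var next = A
    var i = 0L
    while (i < steps) {
      next = next match {
        case A => doA()
        case B => doB()
        case C => doC()
      }
      i += 1
    }
    next
  }
}
```

This avoids per-step allocation at the cost of encoding the control flow by hand: adding a function means adding a tag and a `case`, rather than just writing another method.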