4

Let me clarify my question by example. This is a standard exponentiation algorithm written with tail recursion in Scala:

def power(x: Double, y: Int): Double = {
  def sqr(z: Double): Double = z * z
  def loop(xx: Double, yy: Int): Double = 
    if (yy == 0) xx
    else if (yy % 2 == 0) sqr(loop(xx, yy / 2))
    else loop(xx * x, yy - 1)

  loop(1.0, y)
}

Here the sqr method is used to square loop's result. Defining a special function for such a simple operation doesn't look like a good idea. But we can't just write loop(..) * loop(..) instead, since that would perform the calculation twice.

We could also write it with a val and without the sqr function:

def power(x: Double, y: Int): Double = {
  def loop(xx: Double, yy: Int): Double = 
    if (yy == 0) xx
    else if (yy % 2 == 0) { val s = loop(xx, yy / 2); s * s }
    else loop(xx * x, yy - 1)

  loop(1.0, y)
}

I can't say that it looks better than the variant with sqr, since it uses a state variable. The first version is more functional; the second is more Scala-friendly.

Anyway, my question is how to deal with cases where you need to post-process a function's result. Does Scala have other ways to achieve that?

Richard Wеrеzaк
Vladimir Kostyukov
  • Just to be clear, you say it is "written with tail recursion", but in fact `loop` cannot be tail-call optimised, because one of the calls to `loop` is not in tail position. So it will consume stack like any other recursive function. – Ben James Jul 04 '13 at 11:13
  • Note that a `val` is not a "state variable"; `val`s are immutable (so long, of course, as the type of the value is immutable). The approach with `val` is equivalent to what would probably get used in a truly purely functional language such as Haskell or Coq (there, the equivalent to `val x = ...; ...` is spelled `let x = ... in ...`). – Antal Spector-Zabusky Jul 12 '13 at 20:51

4 Answers

6

You are using the law that

x^(2n) = x^n * x^n

But this is the same as

x^n * x^n = (x*x)^n

Hence, to avoid squaring after the recursion, in the case where y is even you can square the base before the recursive call and halve the exponent, as in the code listing below.

This way, tail-calling will be possible. Here is the full code (not knowing Scala, I hope I get the syntax right by analogy):

def power(x: Double, y: Int): Double = {
    def loop(xx: Double, acc: Double, yy: Int): Double = 
      if (yy == 0) acc
      else if (yy % 2 == 0) loop(xx*xx, acc, yy / 2)
      else loop(xx, acc * xx, yy - 1)

    loop(x, 1.0, y)
}

Here it is in a Haskell-like language:

power2 x n = loop x 1 n 
    where 
        loop x a 0 = a 
        loop x a n = if odd n then loop x    (a*x) (n-1) 
                              else loop (x*x) a    (n `quot` 2)
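
As a rough check of the Scala version above, here is a minimal sketch that adds the standard `@tailrec` annotation, so the compiler rejects `loop` if any recursive call is not in tail position. The wrapper object `PowerSketch` and the `require` guard against negative exponents are assumptions added for illustration; they are not part of the original code.

```scala
import scala.annotation.tailrec

object PowerSketch {
  def power(x: Double, y: Int): Double = {
    // Assumption: only non-negative exponents are supported.
    require(y >= 0, "exponent must be non-negative")

    // Invariant kept by every call: acc * xx^yy == x^y
    @tailrec
    def loop(xx: Double, acc: Double, yy: Int): Double =
      if (yy == 0) acc                                // xx^0 == 1, so acc is the answer
      else if (yy % 2 == 0) loop(xx * xx, acc, yy / 2) // xx^(2n) == (xx*xx)^n
      else loop(xx, acc * xx, yy - 1)                  // move one factor into acc

    loop(x, 1.0, y)
  }
}
```

With these definitions, `PowerSketch.power(2.0, 10)` evaluates to `1024.0` without growing the stack, since every recursive call is the last thing `loop` does.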
Ingo
  • The idea is great, but we can't use it here: we have to do this in a bottom-up manner (not pass it as an argument, but perform it as post-processing). I just tried it. But the idea is brilliant. – Vladimir Kostyukov Jul 04 '13 at 12:58
  • It's not that brilliant, actually I believe this is the standard way to implement integral power. But I understand that you look for a way to post-process the value where post-processing can not be avoided that way. – Ingo Jul 04 '13 at 13:03
  • Ingo, could you please edit your answer with improved `power` function by your suggestion? I'm just saying, that I'm not sure that this is possible in some way. – Vladimir Kostyukov Jul 04 '13 at 13:09
  • You're right, Vladimir, actually one must also pass the original `x` through the loop. – Ingo Jul 04 '13 at 13:27
  • Ingo, please, change the else branch to `else loop(xx, acc * xx, yy - 1)`. There is an error. – Vladimir Kostyukov Jul 04 '13 at 13:29
  • This solution is great (what a nice catch) for this concrete example of code. But I was looking for a general solution, so I have to mark another answer as the accepted one. Sorry. – Vladimir Kostyukov Jul 04 '13 at 16:46
  • This is the solution I lean toward also. What you've done here is define an invariant (acc * (xx ^ yy)) = x ^ y, and pass it along in the tail recursion. Many problems can be solved in a similar manner. – WorBlux Jul 07 '13 at 18:32
5

You could use a "forward pipe". I got this idea from here: Cache an intermediate variable in an one-liner.

So

val s = loop(xx, yy / 2); s * s

could be rewritten to

loop(xx, yy / 2) |> (s => s * s)

using an implicit conversion like this

implicit class PipedObject[A](value: A) {
  def |>[B](f: A => B): B = f(value)
}

As Petr has pointed out, using an implicit value class

object PipedObjectContainer {
  implicit class PipedObject[A](val value: A) extends AnyVal {
    def |>[B](f: A => B): B = f(value)
  }
}

to be used like this

import PipedObjectContainer._
loop(xx, yy / 2) |> (s => s * s)

is better, since it does not need a temporary instance (requires Scala >= 2.10).
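
To see the pipe inside a complete function, here is a small self-contained sketch. The object name `PipeSketch` and the plain (not tail-recursive) formulation of `power` are assumptions chosen for illustration; the point is only that the recursive result is computed once and then post-processed by the lambda.

```scala
object PipeSketch {
  // Value class: adds |> to any value without allocating a wrapper at runtime.
  implicit class PipedObject[A](val value: A) extends AnyVal {
    def |>[B](f: A => B): B = f(value)
  }

  def power(x: Double, y: Int): Double =
    if (y == 0) 1.0
    else if (y % 2 == 0) power(x, y / 2) |> (s => s * s) // square the single recursive result
    else x * power(x, y - 1)
}
```

This still consumes stack, because the squaring happens after the recursive call returns; the pipe only removes the need for a named helper or an intermediate `val`.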

Beryllium
  • Just a note - with Scala 2.10 this can be made slightly more efficient by defining `implicit class PipedObject[A](val value: A) extends AnyVal ...`. This creates a [user-defined value class](http://www.scala-lang.org/api/current/index.html#scala.AnyVal) which is treated specially by the compiler and avoids allocation of new objects at runtime. – Petr Jul 04 '13 at 13:08
  • I just found out this operator is [defined in Scalaz](https://github.com/scalaz/scalaz/blob/scalaz-seven/core/src/main/scala/scalaz/syntax/IdOps.scala), just `import scalaz.syntax.id._` (although it seems it doesn't use that Scala 2.10 feature so it does allocate new objects at runtime). See also http://stackoverflow.com/a/17450595/1333025 – Petr Jul 04 '13 at 15:48
2

In my comment I pointed out that your implementations can't be tail call optimised, because in the case where yy % 2 == 0, there is a recursive call that is not in tail position. So, for a large input, this can overflow the stack.

A general solution to this is to trampoline your function, replacing recursive calls with data which can be mapped over with "post-processing" such as sqr. The result is then computed by an interpreter, which steps through the return values, storing them on the heap rather than the stack.

The Scalaz library provides an implementation of the data types and interpreter.

import scalaz.Free.Trampoline, scalaz.Trampoline._

def sqr(z: Double): Double = z * z

def power(x: Double, y: Int): Double = {
  def loop(xx: Double, yy: Int): Trampoline[Double] =
    if (yy == 0)
      done(xx)
    else if (yy % 2 == 0)
      suspend(loop(xx, yy / 2)) map sqr
    else
      suspend(loop(xx * x, yy - 1))

  loop(1.0, y).run
}

There is a considerable performance hit for doing this, though. In this particular case, I would use Ingo's solution to avoid the need to call sqr at all. But the technique described above can be useful when you can't make such optimisations to your algorithm.
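
If pulling in Scalaz is not an option, the Scala standard library offers a similar trampoline in `scala.util.control.TailCalls`. The following is a minimal sketch, assuming a Scala version in which `TailRec` provides `map` (2.11 or later, as far as I know). The object name `TrampolineSketch` and the accumulator-free formulation of `loop` are illustrative choices, not the code from the listing above.

```scala
import scala.util.control.TailCalls._

object TrampolineSketch {
  def sqr(z: Double): Double = z * z

  def power(x: Double, y: Int): Double = {
    // Each recursive call is wrapped in tailcall, so it is deferred
    // and later run by the trampoline instead of nesting on the stack.
    def loop(yy: Int): TailRec[Double] =
      if (yy == 0) done(1.0)
      else if (yy % 2 == 0) tailcall(loop(yy / 2)).map(sqr) // post-process: square
      else tailcall(loop(yy - 1)).map(_ * x)                // post-process: multiply by x

    loop(y).result // runs the trampoline to completion
  }
}
```

The post-processing steps (`map(sqr)` and `map(_ * x)`) are stored as data on the heap and applied when `result` interprets the chain, which is the same idea as `suspend(...) map sqr` with the Scalaz `Trampoline`.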

Ben James
0

In this particular case

  • No need for utility functions
  • No need for obtuse piping / implicits
  • Only need a single standalone recursive call at end - to always give tail recursion

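    def power(x: Double, y: Int, acc: Double = 1.0): Double =
      if (y == 0) acc // invariant: acc * x^y stays equal to the requested power
      else {
        val evenPower = y % 2 == 0
        power(
          if (evenPower) x * x else x,
          if (evenPower) y / 2 else y - 1,
          if (evenPower) acc else acc * x // odd exponent: move one factor into acc
        )
      }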
    
Glen Best