
I wrote a naive test bed to measure the performance of three kinds of factorial implementation: loop-based, non-tail-recursive, and tail-recursive.

To my surprise, the worst performers were the loop versions (I expected «while» to be more efficient, so I provided both), which took almost twice as long as the tail-recursive alternative.

ANSWER: after fixing the loop implementations to avoid the *= operator, which performs worst with BigInt because of its internals, the «loops» became the fastest, as expected.

Another «voodoo» behavior I experienced was the StackOverflow exception, which wasn't thrown systematically for the same input in the non-tail-recursive implementation. I can circumvent the StackOverflow by progressively calling the function with larger and larger values… I feel crazy :) Answer: the JVM needs to converge during startup; after that the behavior is coherent and systematic.

This is the code:

import scala.annotation.tailrec

final object Factorial {
  type Out = BigInt

  def calculateByRecursion(n: Int): Out = {
    require(n>0, "n must be positive")

    n match {
      case _ if n == 1 => return 1
      case _ => return n * calculateByRecursion(n-1)
    }
  }

  def calculateByForLoop(n: Int): Out = {
    require(n>0, "n must be positive")

    var accumulator: Out = 1
    for (i <- 1 to n)
      accumulator = i * accumulator
    accumulator
  }

  def calculateByWhileLoop(n: Int): Out = {
    require(n>0, "n must be positive")

    var accumulator: Out = 1
    var i = 1
    while (i <= n) {
      accumulator = i * accumulator
      i += 1
    }
    accumulator
  }

  def calculateByTailRecursion(n: Int): Out = {
    require(n>0, "n must be positive")

    @tailrec def fac(n: Int, acc: Out): Out = n match {
      case _ if n == 1 => acc
      case _ => fac(n-1, n * acc)
    }

    fac(n, 1)
  }

  def calculateByTailRecursionUpward(n: Int): Out = {
    require(n>0, "n must be positive")

    @tailrec def fac(i: Int, acc: Out): Out = n match {
      case _ if i == n => n * acc
      case _ => fac(i+1, i * acc)
    }

    fac(1, 1)
  }

  def comparePerformance(n: Int): Unit = {
    def showOutput[A](msg: String, data: (Long, A), showOutput:Boolean = false) =
      showOutput match {
        case true => printf("%s returned %s in %d ms\n", msg, data._2.toString, data._1)
        case false => printf("%s in %d ms\n", msg, data._1)
      }
    def measure[A](f:()=>A): (Long, A) = {
      val start = System.currentTimeMillis
      val o = f()
      (System.currentTimeMillis - start, o)
    }
    showOutput ("By for loop", measure(()=>calculateByForLoop(n)))
    showOutput ("By while loop", measure(()=>calculateByWhileLoop(n)))
    showOutput ("By non-tail recursion", measure(()=>calculateByRecursion(n)))
    showOutput ("By tail recursion", measure(()=>calculateByTailRecursion(n)))
    showOutput ("By tail recursion upward", measure(()=>calculateByTailRecursionUpward(n)))
  }
}

What follows is some output from sbt console (Before «while» implementation):

scala> example.Factorial.comparePerformance(10000)
By loop in 3 ms
By non-tail recursion in >>>>> StackOverflow!!!!!… see later!!!
........

scala> example.Factorial.comparePerformance(1000)
By loop in 3 ms
By non-tail recursion in 1 ms
By tail recursion in 4 ms

scala> example.Factorial.comparePerformance(5000)
By loop in 105 ms
By non-tail recursion in 27 ms
By tail recursion in 34 ms

scala> example.Factorial.comparePerformance(10000)
By loop in 236 ms
By non-tail recursion in 106 ms     >>>> Now works!!!
By tail recursion in 127 ms

scala> example.Factorial.comparePerformance(20000)
By loop in 977 ms
By non-tail recursion in 495 ms
By tail recursion in 564 ms

scala> example.Factorial.comparePerformance(30000)
By loop in 2285 ms
By non-tail recursion in 1183 ms
By tail recursion in 1281 ms

What follows is some output from sbt console (After «while» implementation):

scala> example.Factorial.comparePerformance(10000)
By for loop in 252 ms
By while loop in 246 ms
By non-tail recursion in 130 ms
By tail recursion in 136 ms

scala> example.Factorial.comparePerformance(20000)
By for loop in 984 ms
By while loop in 1091 ms
By non-tail recursion in 508 ms
By tail recursion in 560 ms

What follows is some output from sbt console (after the «upward» tail recursion implementation); the world comes back sane:

scala> example.Factorial.comparePerformance(10000)
By for loop in 259 ms
By while loop in 229 ms
By non-tail recursion in 114 ms
By tail recursion in 119 ms
By tail recursion upward in 105 ms

scala> example.Factorial.comparePerformance(20000)
By for loop in 1053 ms
By while loop in 957 ms
By non-tail recursion in 513 ms
By tail recursion in 565 ms
By tail recursion upward in 470 ms

What follows is some output from sbt console after fixing BigInt multiplication in the «loops» (the world is totally sane):

scala> example.Factorial.comparePerformance(20000)
By for loop in 498 ms
By while loop in 502 ms
By non-tail recursion in 521 ms
By tail recursion in 611 ms
By tail recursion upward in 503 ms

BigInt overhead and a stupid implementation by me masked the expected behavior.

PS: In the end I should re-title this post to «A lesson learned on BigInts»

Jason Aller
Lord of the Goo
  • 8
    what is the question? – Dan Feb 28 '13 at 15:01
  • 1
    "Loops should be avoided" This is misleading, for-loops are generally far slower than equivalent while loops in scala. Tail recursion is usually faster than non-tail recursion, I'm going to go out on a limb and say it's slower because of the closure you're creating before the function starts. – Score_Under Feb 28 '13 at 15:03
  • 1
    Who can help me understand this behaviour by pointing me to an authoritative reference? – Lord of the Goo Feb 28 '13 at 15:03
  • @Score_Under: I don't see any closure at work in the tail-recursive implementation; which one do you mean? – Lord of the Goo Feb 28 '13 at 15:06
  • Rex nailed the main issues, but you should research benchmarking on the JVM, because you are doing it wrong. It is very, very hard to benchmark properly on the JVM, and it's hard to extrapolate anything useful from the results (what applies to microbenchmarks does not apply in a larger context). Of note here, you are not garbage collecting before starting the tests, which is bound to produce large variations. – Daniel C. Sobral Feb 28 '13 at 15:13
  • @DanielC.Sobral - I'm not sure it's wrong. What is `Timer.measure`? That could be doing things right for all I know. – Rex Kerr Feb 28 '13 at 15:16
  • I copied-and-pasted this into my IDE and found the timings for tail and non-tail recursion to fluctuate and overtake each other constantly. However, I'd had to define Timer.measure myself. I then altered Timer.measure to take an average of the time taken over 1000 runs (calculating factorial of 100) and it removed most of the fluctuation, giving the result of 101.4μs for the range iterator, 25.3μs for the non-tail recursion, and 23.3μs for tail recursion. (subsequent runs fluctuated by at most 0.4μs either way.) edit - adding System.gc() before each timing widened the gap. 94μs vs 26μs vs 20μs – Score_Under Feb 28 '13 at 15:19
  • "Loops should be avoided" is a horrible over-generalization. In some cases (the proverbial "inner loop") where the cost of the body of the iterative computation is itself cheap and the iteration overhead is significant (possibly dominant), it might be justified to recast the code using a less concise formulation, but it's foolish to simply "avoid for comprehensions!" – Randall Schulz Feb 28 '13 at 15:20
  • I would recommend you to have a look at [this article](http://www.ibm.com/developerworks/java/library/j-benchmark1/index.html) to know which pitfalls you may stumble upon during jvm-language benchmarking. – om-nom-nom Feb 28 '13 at 15:24
  • @om-nom-nom - It's worth pointing out that in this case almost all the effects LotG measured are _true_. It's just that reasonable-seeming assumptions were violated (like a*b and b*a are the same speed regardless of contents). The only measurement-affected-your-answer thing was the stack overflow, which got optimized away. – Rex Kerr Feb 28 '13 at 15:40
  • By the way, I don't know what it is that you're measuring, but it's _not_ nanoseconds. Maybe milliseconds? – Rex Kerr Feb 28 '13 at 15:42
  • @RexKerr What `Timer.measure`? The code above has the implementation of `comparePerformance`, `showOutput` and `measure`? Maybe we were looking at different edits? – Daniel C. Sobral Feb 28 '13 at 18:32
  • @DanielC.Sobral - Yeah, with this edit I can see what you were worried about. – Rex Kerr Feb 28 '13 at 19:33
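
The benchmarking advice in the comments above (warm up first, average over many runs, trigger a GC before timing) can be sketched roughly as follows. This is only an illustration, not a replacement for a real harness: `Bench` and `averageMicros` are hypothetical names that do not appear in the original code, and `System.gc()` is merely a hint to the JVM.

```scala
object Bench {
  // Hypothetical helper: warm up, then report the mean time per call in microseconds.
  def averageMicros[A](warmup: Int, runs: Int)(f: () => A): Double = {
    var i = 0
    while (i < warmup) { f(); i += 1 }   // give the JIT a chance to compile f
    System.gc()                          // only a hint: try to start from a cleaner heap
    val start = System.nanoTime()        // nanoTime is meant for measuring elapsed time
    i = 0
    while (i < runs) { f(); i += 1 }
    (System.nanoTime() - start) / 1000.0 / runs
  }
}
```

With the question's object in scope, a call could look like `Bench.averageMicros(1000, 1000)(() => Factorial.calculateByWhileLoop(100))`. The numbers will still vary between machines and JVM versions.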

3 Answers


For loops are not quite loops; they're for comprehensions on a range. If you actually want a loop, you need to use while. (Actually, I think the BigInt multiplication here is heavyweight enough that it shouldn't matter. But you'll notice if you're multiplying Ints.)
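
As a rough illustration of that point (the object and method names here are made up for the example), a `for` over a range is rewritten by the compiler into a `foreach` call that takes a function, whereas `while` stays a plain loop:

```scala
object ForDesugaring {
  // What you write: a for comprehension over the Range `1 to n`.
  def sumWithFor(n: Int): Long = {
    var total = 0L
    for (i <- 1 to n) total += i
    total
  }

  // Roughly what it becomes: Range.foreach applied to a function literal.
  def sumWithForeach(n: Int): Long = {
    var total = 0L
    (1 to n).foreach(i => total += i)
    total
  }

  // A plain while loop: no Range object and no function value involved.
  def sumWithWhile(n: Int): Long = {
    var total = 0L
    var i = 1
    while (i <= n) { total += i; i += 1 }
    total
  }
}
```

All three return the same value; any difference is in the per-iteration overhead, which only matters when the loop body is cheap.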

Also, you have confused yourself by using BigInt. The bigger your BigInt is, the slower your multiplication. So your non-tail loop counts up while your tail recursion loop counts down, which means that the latter has more big numbers to multiply.
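
One way to see the effect without timing anything is to add up how large the accumulator is at each multiplication. The sketch below uses hypothetical helper names and takes `bitLength` as a rough proxy for multiplication cost, which is an assumption rather than a measurement:

```scala
object OperandSizes {
  // Counting down (n, n-1, ..., 2): the accumulator is multiplied by the
  // largest factors first, so it becomes big early.
  def downwardBits(n: Int): Long = {
    var acc = BigInt(1)
    var bits = 0L
    var i = n
    while (i > 1) {
      bits += acc.bitLength   // size of the big operand in this multiplication
      acc = acc * i
      i -= 1
    }
    bits
  }

  // Counting up (2, 3, ..., n): the accumulator stays smaller for longer.
  def upwardBits(n: Int): Long = {
    var acc = BigInt(1)
    var bits = 0L
    var i = 2
    while (i <= n) {
      bits += acc.bitLength
      acc = acc * i
      i += 1
    }
    bits
  }
}
```

For `n = 5` the accumulators are 1, 5, 20, 60 when counting down and 1, 2, 6, 24 when counting up, so `downwardBits(5)` is 15 while `upwardBits(5)` is 11. The gap widens as `n` grows.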

If you fix these two issues you will find that sanity is restored: loops and tail recursion are the same speed, with both regular recursion and for slower. (Regular recursion may not be slower if the JVM optimization makes it equivalent)

(Also, the stack overflow fix is probably because the JVM starts inlining and may either make the call tail-recursive itself, or unrolls the loop far enough so that you don't overflow any longer.)

Finally, you're getting poor results with for and while because you're multiplying with the small number on the right rather than on the left. It turns out that Java's BigInteger, which backs Scala's BigInt, multiplies faster with the smaller number on the left.
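
If you want to check the operand-order effect on your own setup, a crude sketch along these lines may help. `OperandOrder` and `nanosPerCall` are made-up names, the operand sizes are arbitrary, and whether any difference shows up at all is likely to depend on the JVM version, so treat the output as indicative only:

```scala
object OperandOrder {
  // Hypothetical helper: mean nanoseconds per evaluation of `body`.
  def nanosPerCall(runs: Int)(body: => BigInt): Double = {
    var sink = BigInt(0)              // keep the result so the work is not discarded
    var i = 0
    val start = System.nanoTime()
    while (i < runs) { sink = body; i += 1 }
    val elapsed = System.nanoTime() - start
    require(sink.signum >= 0)         // uses sink; products of positives are never negative
    elapsed.toDouble / runs
  }

  def main(args: Array[String]): Unit = {
    val big   = (1 to 5000).map(BigInt(_)).product   // 5000!
    val small = BigInt(12345)
    // Warm-up passes so both orderings are JIT-compiled before measuring.
    nanosPerCall(2000)(small * big)
    nanosPerCall(2000)(big * small)
    println(f"small * big: ${nanosPerCall(2000)(small * big)}%.0f ns per call")
    println(f"big * small: ${nanosPerCall(2000)(big * small)}%.0f ns per call")
  }
}
```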

Rex Kerr
  • I'm confused. By looking at the code I would say that both count down. Is it due to tail recursion that the order has somehow changed? – bluenote10 Feb 28 '13 at 15:13
  • I've just added a «while» implementation that performs about the same as the «for»… give me a minute to integrate your pearl! – Lord of the Goo Feb 28 '13 at 15:14
  • 1
    @bluenote10 - Regular recursion calculates the deepest value first. Tail recursion does the opposite, as written. Write out a few terms and you'll see it. – Rex Kerr Feb 28 '13 at 15:15
  • And what is your opinion about the lack of determinism in throwing StackOverflow exceptions?… I'm still of the opinion that some kind of memoization is at work under the hood – Lord of the Goo Feb 28 '13 at 15:17
  • @LordoftheGoo - I added that to the answer. It's hard to know for sure without a JVM kitted out to print the assembly during inlining. – Rex Kerr Feb 28 '13 at 15:18
  • @RexKerr: Thanks for making that clear. From your answer I would take it your claim was that the recursion logic itself was different... – bluenote10 Feb 28 '13 at 15:21
  • What still holds (and, honestly, I haven't figured it out from your replies) is the bad performance of the loop implementations compared with the tail-recursive one – Lord of the Goo Feb 28 '13 at 16:30
  • @LordoftheGoo - Please see my "finally"--the last paragraph. It has to do with noncommutative timing in multiplication. – Rex Kerr Feb 28 '13 at 16:38
  • Thanks Rex, I've fixed it: I should consider changing the title of this post, which essentially shows off my incompetence in using BigInt – Lord of the Goo Feb 28 '13 at 16:53

Scala static methods for factorial(n) (coded with scala 2.12.x, java-8):

object Factorial {

  /*
   * For large N, it throws a stack overflow
   */
  def recursive(n:BigInt): BigInt = {
    if(n < 0) {
      throw new ArithmeticException
    } else if(n <= 1) {
      1
    } else {
      n * recursive(n - 1)
    }
  }

  /*
   * A tail recursive method is compiled to avoid stack overflow
   */
  @scala.annotation.tailrec
  def recursiveTail(n:BigInt, acc:BigInt = 1): BigInt = {
    if(n < 0) {
      throw new ArithmeticException
    } else if(n <= 1) {
      acc
    } else {
      recursiveTail(n - 1, n * acc)
    }
  }

  /*
   * A while loop
   */
  def loop(n:BigInt): BigInt = {
    if(n < 0) {
      throw new ArithmeticException
    } else if(n <= 1) {
      1
    } else {
      var acc: BigInt = 1 // must be BigInt; with `var acc = 1` the product is Int arithmetic and overflows
      var idx = 1
      while(idx <= n) {
        acc = idx * acc
        idx += 1
      }
      acc
    }
  }

}

Specs:

class FactorialSpecs extends SpecHelper {

  private val smallInt = 10
  private val largeInt = 10000

  describe("Factorial.recursive") {
    it("return 1 for 0") {
      assert(Factorial.recursive(0) == 1)
    }
    it("return 1 for 1") {
      assert(Factorial.recursive(1) == 1)
    }
    it("return 2 for 2") {
      assert(Factorial.recursive(2) == 2)
    }
    it("returns a result, for small inputs") {
      assert(Factorial.recursive(smallInt) == 3628800)
    }
    it("throws StackOverflow for large inputs") {
      intercept[java.lang.StackOverflowError] {
        Factorial.recursive(Int.MaxValue)
      }
    }
  }

  describe("Factorial.recursiveTail") {
    it("return 1 for 0") {
      assert(Factorial.recursiveTail(0) == 1)
    }
    it("return 1 for 1") {
      assert(Factorial.recursiveTail(1) == 1)
    }
    it("return 2 for 2") {
      assert(Factorial.recursiveTail(2) == 2)
    }
    it("returns a result, for small inputs") {
      assert(Factorial.recursiveTail(smallInt) == 3628800)
    }
    it("returns a result, for large inputs") {
      assert(Factorial.recursiveTail(largeInt).isInstanceOf[BigInt])
    }
  }

  describe("Factorial.loop") {
    it("return 1 for 0") {
      assert(Factorial.loop(0) == 1)
    }
    it("return 1 for 1") {
      assert(Factorial.loop(1) == 1)
    }
    it("return 2 for 2") {
      assert(Factorial.loop(2) == 2)
    }
    it("returns a result, for small inputs") {
      assert(Factorial.loop(smallInt) == 3628800)
    }
    it("returns a result, for large inputs") {
      assert(Factorial.loop(largeInt).isInstanceOf[BigInt])
    }
  }
}
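
The large-input checks above only assert the result type, which holds for any value returned from a method declared to return `BigInt`. A stricter check would compare each implementation against an independently computed reference. The following is a minimal sketch under that assumption; `FactorialCheck`, `reference` and `agrees` are hypothetical names, not part of the specs above:

```scala
object FactorialCheck {
  // Reference value built from the standard library; an empty range yields 1, so 0! = 1.
  def reference(n: Int): BigInt = (1 to n).map(BigInt(_)).product

  // Hypothetical helper: true when `impl` matches the reference for every input.
  def agrees(impl: BigInt => BigInt, inputs: Seq[Int]): Boolean =
    inputs.forall(n => impl(BigInt(n)) == reference(n))
}
```

Inside a spec this could be used as, for example, `assert(FactorialCheck.agrees(Factorial.loop, Seq(0, 1, 10, 25)))`. An implementation that accumulates in an `Int` fails such a check as soon as the product overflows, which happens from 13! onward.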

Benchmarks:

import org.scalameter.api._

class BenchmarkFactorials extends Bench.OfflineReport {

  val gen: Gen[Int] = Gen.range("N")(1, 1000, 100) // scalastyle:ignore

  performance of "Factorial" in {
    measure method "loop" in {
      using(gen) in {
        n => Factorial.loop(n)
      }
    }
    measure method "recursive" in {
      using(gen) in {
        n => Factorial.recursive(n)
      }
    }
    measure method "recursiveTail" in {
      using(gen) in {
        n => Factorial.recursiveTail(n)
      }
    }
  }

}

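Benchmark results (loop is much faster, but note that `loop` as benchmarked declared `acc` as a plain `Int`, so it measured overflowing `Int` arithmetic rather than `BigInt` multiplication):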

[info] Test group: Factorial.loop
[info] - Factorial.loop.Test-9 measurements:
[info]   - at N -> 1: passed
[info]     (mean = 0.01 ms, ci = <0.00 ms, 0.02 ms>, significance = 1.0E-10)
[info]   - at N -> 101: passed
[info]     (mean = 0.01 ms, ci = <0.01 ms, 0.01 ms>, significance = 1.0E-10)
[info]   - at N -> 201: passed
[info]     (mean = 0.02 ms, ci = <0.02 ms, 0.02 ms>, significance = 1.0E-10)
[info]   - at N -> 301: passed
[info]     (mean = 0.03 ms, ci = <0.02 ms, 0.03 ms>, significance = 1.0E-10)
[info]   - at N -> 401: passed
[info]     (mean = 0.03 ms, ci = <0.03 ms, 0.04 ms>, significance = 1.0E-10)
[info]   - at N -> 501: passed
[info]     (mean = 0.04 ms, ci = <0.03 ms, 0.05 ms>, significance = 1.0E-10)
[info]   - at N -> 601: passed
[info]     (mean = 0.04 ms, ci = <0.04 ms, 0.05 ms>, significance = 1.0E-10)
[info]   - at N -> 701: passed
[info]     (mean = 0.05 ms, ci = <0.05 ms, 0.05 ms>, significance = 1.0E-10)
[info]   - at N -> 801: passed
[info]     (mean = 0.06 ms, ci = <0.05 ms, 0.06 ms>, significance = 1.0E-10)
[info]   - at N -> 901: passed
[info]     (mean = 0.06 ms, ci = <0.05 ms, 0.07 ms>, significance = 1.0E-10)

[info] Test group: Factorial.recursive
[info] - Factorial.recursive.Test-10 measurements:
[info]   - at N -> 1: passed
[info]     (mean = 0.00 ms, ci = <0.00 ms, 0.01 ms>, significance = 1.0E-10)
[info]   - at N -> 101: passed
[info]     (mean = 0.05 ms, ci = <0.01 ms, 0.09 ms>, significance = 1.0E-10)
[info]   - at N -> 201: passed
[info]     (mean = 0.03 ms, ci = <0.02 ms, 0.05 ms>, significance = 1.0E-10)
[info]   - at N -> 301: passed
[info]     (mean = 0.07 ms, ci = <0.00 ms, 0.13 ms>, significance = 1.0E-10)
[info]   - at N -> 401: passed
[info]     (mean = 0.09 ms, ci = <0.01 ms, 0.18 ms>, significance = 1.0E-10)
[info]   - at N -> 501: passed
[info]     (mean = 0.10 ms, ci = <0.03 ms, 0.17 ms>, significance = 1.0E-10)
[info]   - at N -> 601: passed
[info]     (mean = 0.11 ms, ci = <0.08 ms, 0.15 ms>, significance = 1.0E-10)
[info]   - at N -> 701: passed
[info]     (mean = 0.13 ms, ci = <0.11 ms, 0.14 ms>, significance = 1.0E-10)
[info]   - at N -> 801: passed
[info]     (mean = 0.16 ms, ci = <0.13 ms, 0.19 ms>, significance = 1.0E-10)
[info]   - at N -> 901: passed
[info]     (mean = 0.21 ms, ci = <0.15 ms, 0.27 ms>, significance = 1.0E-10)

[info] Test group: Factorial.recursiveTail
[info] - Factorial.recursiveTail.Test-11 measurements:
[info]   - at N -> 1: passed
[info]     (mean = 0.00 ms, ci = <0.00 ms, 0.01 ms>, significance = 1.0E-10)
[info]   - at N -> 101: passed
[info]     (mean = 0.04 ms, ci = <0.03 ms, 0.05 ms>, significance = 1.0E-10)
[info]   - at N -> 201: passed
[info]     (mean = 0.12 ms, ci = <0.05 ms, 0.20 ms>, significance = 1.0E-10)
[info]   - at N -> 301: passed
[info]     (mean = 0.16 ms, ci = <-0.03 ms, 0.34 ms>, significance = 1.0E-10)
[info]   - at N -> 401: passed
[info]     (mean = 0.12 ms, ci = <0.09 ms, 0.16 ms>, significance = 1.0E-10)
[info]   - at N -> 501: passed
[info]     (mean = 0.17 ms, ci = <0.15 ms, 0.19 ms>, significance = 1.0E-10)
[info]   - at N -> 601: passed
[info]     (mean = 0.23 ms, ci = <0.19 ms, 0.26 ms>, significance = 1.0E-10)
[info]   - at N -> 701: passed
[info]     (mean = 0.25 ms, ci = <0.18 ms, 0.32 ms>, significance = 1.0E-10)
[info]   - at N -> 801: passed
[info]     (mean = 0.28 ms, ci = <0.21 ms, 0.36 ms>, significance = 1.0E-10)
[info]   - at N -> 901: passed
[info]     (mean = 0.32 ms, ci = <0.17 ms, 0.46 ms>, significance = 1.0E-10)
Darren Weber

I know everyone already answered the question, but I thought that I might add this one optimization: If you convert the pattern matching to simple if-statements, it can speed up the tail recursion.

final object Factorial {
  type Out = BigInt

  def calculateByRecursion(n: Int): Out = {
    require(n>0, "n must be positive")

    n match {
      case _ if n == 1 => return 1
      case _ => return n * calculateByRecursion(n-1)
    }
  }

  def calculateByForLoop(n: Int): Out = {
    require(n>0, "n must be positive")

    var accumulator: Out = 1
    for (i <- 1 to n)
      accumulator = i * accumulator
    accumulator
  }

  def calculateByWhileLoop(n: Int): Out = {
    require(n>0, "n must be positive")

    var acc: Out = 1
    var i = 1
    while (i <= n) {
      acc = i * acc
      i += 1
    }
    acc
  }

  def calculateByTailRecursion(n: Int): Out = {
    require(n>0, "n must be positive")

    @annotation.tailrec
    def fac(n: Int, acc: Out): Out = if (n==1) acc else fac(n-1, n*acc)

    fac(n, 1)
  }

  def calculateByTailRecursionUpward(n: Int): Out = {
    require(n>0, "n must be positive")

    @annotation.tailrec
    def fac(i: Int, acc: Out): Out = if (i == n) n*acc else fac(i+1, i*acc)

    fac(1, 1)
  }

  def attempt(f: ()=>Unit): Boolean = {
    try {
        f()
        true
    } catch {
        case _: Throwable =>
            println(" <<<<< Failed...")
            false
    }
  }

  def comparePerformance(n: Int) {
    def showOutput[A](msg: String, data: (Long, A), showOutput:Boolean = true) =
      showOutput match {
        case true =>
            val res = data._2.toString
            val pref = res.substring(0,5)
            val midd = res.substring((res.length-5)/ 2, (res.length-5)/ 2 + 10)
            val suff = res.substring(res.length-5)
            printf("%s returned %s in %d ms\n", msg, s"$pref...$midd...$suff" , data._1)
        case false => 
            printf("%s in %d ms\n", msg, data._1)
    }
    def measure[A](f:()=>A): (Long, A) = {
      val start = System.currentTimeMillis
      val o = f()
      (System.currentTimeMillis - start, o)
    }
    attempt(() => showOutput ("By for loop", measure(()=>calculateByForLoop(n))))
    attempt(() => showOutput ("By while loop", measure(()=>calculateByWhileLoop(n))))
    attempt(() => showOutput ("By non-tail recursion", measure(()=>calculateByRecursion(n))))
    attempt(() => showOutput ("By tail recursion", measure(()=>calculateByTailRecursion(n))))
    attempt(() => showOutput ("By tail recursion upward", measure(()=>calculateByTailRecursionUpward(n))))
  }
}

My results:

scala> Factorial.comparePerformance(20000)
By for loop returned 18192...5708616582...00000 in 179 ms
By while loop returned 18192...5708616582...00000 in 159 ms
By non-tail recursion <<<<< Failed...
By tail recursion returned 18192...5708616582...00000 in 169 ms
By tail recursion upward returned 18192...5708616582...00000 in 174 ms
By for loop returned 18192...5708616582...00000 in 212 ms
By while loop returned 18192...5708616582...00000 in 156 ms
By non-tail recursion returned 18192...5708616582...00000 in 155 ms
By tail recursion returned 18192...5708616582...00000 in 166 ms
By tail recursion upward returned 18192...5708616582...00000 in 137 ms
scala> Factorial.comparePerformance(200000)
By for loop returned 14202...0169293868...00000 in 17467 ms
By while loop returned 14202...0169293868...00000 in 17303 ms
By non-tail recursion <<<<< Failed...
By tail recursion returned 14202...0169293868...00000 in 18477 ms
By tail recursion upward returned 14202...0169293868...00000 in 17188 ms