
Suppose a function is applied repeatedly to produce a numeric result. The iteration stops either when the maximum number of iterations is reached or when an "optimality" condition is met. In either case, the value from the current iteration is returned. What is a functional way to get both this result and the reason for stopping?

For illustration, here's my Scala implementation of the "Square Roots" example in section 4.1 of https://www.cs.kent.ac.uk/people/staff/dat/miranda/whyfp90.pdf.

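object SquareRootAlg {
  // One Newton step towards sqrt(a).
  def next(a: Double)(x: Double): Double = (x + a / x) / 2
  // Lazy infinite stream a, f(a), f(f(a)), ...
  def repeat[A](f: A => A, a: A): Stream[A] = a #:: repeat(f, f(a))

  // Returns the second element of the first consecutive pair satisfying `stop`,
  // or the last element if the stream runs out first.
  def loopConditional[A](stop: (A, A) => Boolean)(s: => Stream[A]): A = s match {
    case a #:: t if t.isEmpty => a
    case a #:: t => if (stop(a, t.head)) t.head else loopConditional(stop)(t)
  }
}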

For example, to find the square root of 4:

import SquareRootAlg._
val cond = (a: Double, b: Double) => (a-b).abs < 0.01
val alg = loopConditional(cond) _
val s = repeat(next(4.0), 4.0)

alg(s.take(3))  // = 2.05, "maxIters exceeded"
alg(s.take(5)) // = 2.00000009, "optimality reached"
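To see why three elements end on the iteration cap while five reach the tolerance, here is a small sketch (not from the question) that lists the first five iterates and the gaps between neighbours. `Iterator.iterate` stands in for `repeat`, and the names `iterates` and `gaps` are illustrative only.

```scala
// One Newton step towards sqrt(a), same formula as `next` above.
def next(a: Double)(x: Double): Double = (x + a / x) / 2

// First five iterates starting from 4.0; Iterator.iterate plays the role of `repeat`.
val iterates: List[Double] = Iterator.iterate(4.0)(next(4.0)).take(5).toList

// Absolute differences between consecutive iterates; the loop stops at the
// first pair whose difference is below 0.01.
val gaps: List[Double] = iterates.zip(iterates.tail).map { case (x, y) => (x - y).abs }

iterates.foreach(println)
gaps.foreach(println)
```

The first gap below 0.01 sits between the fourth and fifth iterates, so `s.take(5)` is the shortest prefix for which the stop condition fires; `s.take(4)` would still end on the iteration cap.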

This code works, but doesn't give me the stopping reason. So I'm trying to write a method

def loopConditionalInfo[A](stop: (A, A) => Boolean)(s: => Stream[A]): (A, Boolean)

outputting (2.05, false) in the first case above, and (2.00000009, true) in the second. Is there a way to write this method without modifying the next and repeat methods? Or would another functional approach work better?

schrödingcöder

2 Answers


Typically, you return a value that includes both the result and the stopping reason. The (A, Boolean) return type you propose does exactly that.

Your code would then become:

import scala.annotation.tailrec

object SquareRootAlg {
  def next(a: Double)(x: Double): Double = (x + a/x)/2
  def repeat[A](f: A=>A, a: A): Stream[A] = a #:: repeat(f, f(a))

  @tailrec // The compiler verifies that the function really is tail recursive.
  def loopConditional[A](stop: (A, A) => Boolean)(s: => Stream[A]): (A, Boolean) = {
    val a = s.head
    val t = s.tail
    if (t.isEmpty) (a, false)                 // ran out of iterations
    else if (stop(a, t.head)) (t.head, true)  // optimality condition met
    else loopConditional(stop)(t)
  }
}
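A Boolean can only distinguish two outcomes and leaves the caller to remember which value means what. A sketch of the same loop that names the reason with a small sealed trait instead; `StopReason`, `MaxItersExceeded`, `OptimalityReached` and `loopConditionalReason` are hypothetical names, and `s` is taken by value here rather than by name.

```scala
import scala.annotation.tailrec

// Hypothetical ADT (not from the question) that names the two ways the loop can end.
sealed trait StopReason
case object MaxItersExceeded extends StopReason
case object OptimalityReached extends StopReason

object SquareRootAlgReason {
  def next(a: Double)(x: Double): Double = (x + a / x) / 2
  def repeat[A](f: A => A, a: A): Stream[A] = a #:: repeat(f, f(a))

  // Same control flow as the Boolean version above; like the original,
  // it assumes `s` is non-empty (s.head throws on an empty stream).
  @tailrec
  def loopConditionalReason[A](stop: (A, A) => Boolean)(s: Stream[A]): (A, StopReason) = {
    val a = s.head
    val t = s.tail
    if (t.isEmpty) (a, MaxItersExceeded)
    else if (stop(a, t.head)) (t.head, OptimalityReached)
    else loopConditionalReason(stop)(t)
  }
}

import SquareRootAlgReason._
val cond = (a: Double, b: Double) => (a - b).abs < 0.01
val roots = repeat(next(4.0), 4.0)

// Callers can match on the reason instead of interpreting a Boolean.
loopConditionalReason(cond)(roots.take(3)) match {
  case (x, MaxItersExceeded)  => println(s"stopped at the iteration cap with $x")
  case (x, OptimalityReached) => println(s"converged to $x")
}
```

Because the trait is sealed, the compiler can warn about a `match` that forgets one of the reasons, and adding a third reason later (for example, a divergence check) does not change the result's shape.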
Mike Allen
  • OP's original code is already tail-recursive, and it's also reasonably clear. I honestly don't see what necessitates a complete rewrite of the method body. – Andrey Tyukin Sep 20 '18 at 20:40
  • @AndreyTyukin You're entitled to your opinion. :-) The primary benefit of `@tailrec` is that it complains if the function isn't tail recursive: so it both states an intention and verifies that it is true. As for the body rewrite, it's a trivial change: the original `case`-based version considers the same elements twice and yet still has the same `if` statements. The version I used was just a little terser. – Mike Allen Sep 20 '18 at 20:45
  • Just took a look at the decompiled code generated by the two-`case` version: it seems to merge the two cases (with and without the guard), so the non-emptiness of `s` is checked only once and `head` and `tail` are each accessed only once. The fact that it "considers the same elements twice" therefore shouldn't have much impact on performance. Checking whether `s` matches the `a #:: t` pattern in the first place could cost some cycles, though, so your version might indeed be faster. – Andrey Tyukin Sep 20 '18 at 20:54

Just return the Boolean together with the value, leaving `next` and `repeat` unchanged:

object SquareRootAlg {
  def next(a: Double)(x: Double): Double = (x + a/x)/2
  def repeat[A](f: A => A, a: A): Stream[A] = a #:: repeat(f, f(a))

  def loopConditionalInfo[A]
    (stop: (A, A)=> Boolean)
    (s: => Stream[A])
  : (A, Boolean) = s match {
    case a #:: t if t.isEmpty => (a, false)
    case a #:: t => 
      if (stop(a, t.head)) (t.head, true) 
      else loopConditionalInfo(stop)(t)
  }
}

import SquareRootAlg._
val cond = (a: Double, b: Double) => (a-b).abs < 0.01
val alg = loopConditionalInfo(cond) _
val s = repeat(next(4.0), 4.0)

println(alg(s.take(3))) // = 2.05, "maxIters exceeded"
println(alg(s.take(5)))

prints

(2.05,false)
(2.0000000929222947,true)
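If the iteration cap should be an explicit parameter rather than a `take(n)` on the stream, the same (value, flag) result can come from a plain tail-recursive loop with a counter, which also avoids `Stream` (deprecated since Scala 2.13 in favour of `LazyList`). A sketch under that assumption; `iterateUntil`, `NewtonLoop` and `sqrt4Step` are hypothetical names. Since the stream prefix includes the starting value, `s.take(n)` corresponds to `maxIters = n - 1`.

```scala
import scala.annotation.tailrec

object NewtonLoop {
  // Hypothetical helper (not from either answer): apply `f` to `x0` at most
  // `maxIters` times, stopping early as soon as `stop(previous, current)` holds.
  // Returns the last value computed and whether `stop` fired.
  def iterateUntil[A](f: A => A, x0: A, maxIters: Int)(stop: (A, A) => Boolean): (A, Boolean) = {
    @tailrec
    def loop(x: A, itersLeft: Int): (A, Boolean) =
      if (itersLeft <= 0) (x, false)
      else {
        val y = f(x)
        if (stop(x, y)) (y, true) else loop(y, itersLeft - 1)
      }
    loop(x0, maxIters)
  }
}

import NewtonLoop._
val sqrt4Step = (x: Double) => (x + 4.0 / x) / 2
val cond = (a: Double, b: Double) => (a - b).abs < 0.01

println(iterateUntil(sqrt4Step, 4.0, 2)(cond)) // same outcome as alg(s.take(3))
println(iterateUntil(sqrt4Step, 4.0, 4)(cond)) // same outcome as alg(s.take(5))
```

The trade-off is that the stopping logic and the sequence generation are fused again, which is exactly what the lazy-stream style in the paper tries to separate; the stream versions above keep them independent.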
Andrey Tyukin