
I have an expensive function which I want to run as few times as possible, with the following requirements:

  • I have several input values to try
  • If the function returns a value below a given threshold, I don't want to try other inputs
  • If no result is below the threshold, I want to take the result with the minimal output

I could not find a nice solution using Iterator's takeWhile/dropWhile, because I want to have the first matching element included. I just ended up with the following solution:

import scala.collection.mutable
import scala.util.control.Breaks._

val pseudoResult = Map("a" -> 0.6,"b" -> 0.2, "c" -> 1.0)

def expensiveFunc(s:String) : Double = {
  pseudoResult(s)
}

val inputsToTry = Seq("a","b","c")

val inputIt = inputsToTry.iterator
val results = mutable.ArrayBuffer.empty[(String, Double)]

val earlyAbort = 0.5 // threshold

breakable {
  while (inputIt.hasNext) {
    val name = inputIt.next()
    val res = expensiveFunc(name)
    results += Tuple2(name,res)
    if (res<earlyAbort) break()
  }
}

println(results) // ArrayBuffer((a,0.6), (b,0.2))

val (name, bestResult) = results.minBy(_._2) // (b, 0.2)

If I set val earlyAbort = 0.1, the result should still be (b, 0.2), without evaluating all the cases a second time.

Raphael Roth

5 Answers

You can use a Stream to achieve what you are looking for. Remember that Stream is a lazy collection that evaluates operations on demand.

Here is the Scala Stream documentation.

You only need to do this:

val pseudoResult = Map("a" -> 0.6,"b" -> 0.2, "c" -> 1.0)
val earlyAbort = 0.5

def expensiveFunc(s: String): Double = {
  println(s"Evaluating for $s")
  pseudoResult(s)
}

val inputsToTry = Seq("a","b","c")

val results = inputsToTry.toStream.map(input => input -> expensiveFunc(input))
val finalResult = results.find { case (k, res) => res < earlyAbort }.getOrElse(results.minBy(_._2))

If find does not return a value, you can use the same stream to find the minimum, and the function is not evaluated again. This is because of memoization:

The Stream class also employs memoization such that previously computed values are converted from Stream elements to concrete values of type A

Note that this code fails if the original collection is empty, because minBy throws an UnsupportedOperationException on an empty collection. If you want to support empty collections, replace minBy with sortBy(_._2).headOption and getOrElse with orElse:

val finalResultOpt = results.find { case (k, res) => res < earlyAbort }.orElse(results.sortBy(_._2).headOption)

And the output for this is:

Evaluating for a
Evaluating for b
finalResult: (String, Double) = (b,0.2)
finalResultOpt: Option[(String, Double)] = Some((b,0.2))
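In Scala 2.13 and later, Stream is deprecated in favor of LazyList, and collections gain minByOption, which makes the sortBy workaround unnecessary. Below is a minimal sketch of the same approach under that assumption (Scala 2.13+ or Scala 3). The calls counter and the best wrapper are hypothetical additions, used only to show how many times expensiveFunc runs:

```scala
// Assumes Scala 2.13+ or Scala 3: LazyList and minByOption do not exist in 2.12.
object LazyListVariant {
  val pseudoResult = Map("a" -> 0.6, "b" -> 0.2, "c" -> 1.0)
  var calls = 0 // hypothetical counter: number of expensiveFunc invocations

  def expensiveFunc(s: String): Double = {
    calls += 1
    pseudoResult(s)
  }

  // hypothetical wrapper: first result below earlyAbort, else the minimum, else None
  def best(inputs: Seq[String], earlyAbort: Double): Option[(String, Double)] = {
    // LazyList evaluates elements on demand and memoizes them, so minByOption
    // reuses the pairs that find has already computed.
    val results = inputs.to(LazyList).map(input => input -> expensiveFunc(input))
    // minByOption returns None on an empty collection instead of throwing
    results.find { case (_, res) => res < earlyAbort }.orElse(results.minByOption(_._2))
  }
}
```

With this sketch, best(Seq("a", "b", "c"), 0.1) returns Some((b,0.2)) after exactly three calls to expensiveFunc, and best(Seq.empty[String], 0.5) returns None. Unlike Stream, LazyList does not evaluate the head eagerly when mapped; both memoize the elements they have evaluated.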

  • @proximator you are wrong, try it. It's lazy because of the Stream. I've edited my original post adding the output; it's clearly not evaluating c. – Gonzalo Guglielmo Aug 09 '18 at 16:44
  • @proximator You are wrong again. Seriously, try it. When the function in map is evaluated for an element in the stream, it will not be evaluated again. Please try my code before commenting, or at least try to understand how a stream works. – Gonzalo Guglielmo Aug 09 '18 at 17:26
  • This is really interesting; I also thought that this would evaluate everything twice if no result is below earlyAbort. – Raphael Roth Aug 10 '18 at 06:07
  • Nice! Does this mean that it memoizes the evaluated parts of the collection? – stefanobaghino Aug 10 '18 at 08:01
  • you should add a quote from the ScalaDoc which says `The Stream class also employs memoization such that previously computed values are converted from Stream elements to concrete values of type A` – Raphael Roth Aug 10 '18 at 11:40
  • thanks! I've edited the answer adding that quote! @stefanobaghino that's right. – Gonzalo Guglielmo Aug 10 '18 at 14:21

This is one of the use cases for tail recursion:

  import scala.annotation.tailrec
  val pseudoResult = Map("a" -> 0.6,"b" -> 0.2, "c" -> 1.0)

  def expensiveFunc(s:String) : Double = {
    pseudoResult(s)
  }

  val inputsToTry = Seq("a","b","c")

  val earlyAbort = 0.5 // threshold

  @tailrec
  def f(s: Seq[String], result: Map[String, Double] = Map()): Map[String, Double] = s match {
    case Nil => result
    case h::t =>
      val expensiveCalculation = expensiveFunc(h)
      val intermediateResult = result + (h -> expensiveCalculation)
      if(expensiveCalculation < earlyAbort) {
        intermediateResult
      } else {
        f(t, intermediateResult)
      }
  }
  val result = f(inputsToTry)

  println(result) // Map(a -> 0.6, b -> 0.2)

  val (name, bestResult) = result.minBy(_._2) // ("b", 0.2); reuses result instead of calling f (and expensiveFunc) again
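To check the earlyAbort = 0.1 case from the question, here is a sketch of the same tail-recursive idea with the threshold passed as a parameter and the input typed as List, so the Nil and :: patterns cover every case. The TailRecCheck object and the calls counter are hypothetical additions that show each input is evaluated only once when the returned map is reused:

```scala
import scala.annotation.tailrec

object TailRecCheck {
  val pseudoResult = Map("a" -> 0.6, "b" -> 0.2, "c" -> 1.0)
  var calls = 0 // hypothetical counter: number of expensiveFunc invocations

  def expensiveFunc(s: String): Double = {
    calls += 1
    pseudoResult(s)
  }

  // Same logic as f above, with the threshold as a parameter and a List input.
  @tailrec
  def f(s: List[String], earlyAbort: Double,
        result: Map[String, Double] = Map()): Map[String, Double] = s match {
    case Nil => result
    case h :: t =>
      val res = expensiveFunc(h)
      val acc = result + (h -> res)
      if (res < earlyAbort) acc else f(t, earlyAbort, acc) // stop at the first hit
  }
}
```

Here f(List("a", "b", "c"), 0.1) evaluates all three inputs once, and calling minBy(_._2) on the returned map yields (b,0.2) without any further calls.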
curious

The clearest, simplest thing to do is to fold over the input, passing forward only the current best result.

val inputIt :Iterator[String] = inputsToTry.iterator
val earlyAbort = 0.5 // threshold

inputIt.foldLeft(("",Double.MaxValue)){ case (low,name) =>
  if (low._2 < earlyAbort) low
  else Seq(low, (name, expensiveFunc(name))).minBy(_._2)
}
//res0: (String, Double) = (b,0.2)

It calls expensiveFunc() only as many times as needed, but it still walks through the entire input iterator. If that's still too onerous (lots of input), then I'd go with a tail-recursive method.

val inputIt :Iterator[String] = inputsToTry.iterator
val earlyAbort = 0.5 // threshold

@scala.annotation.tailrec // every recursive call to bestMin is in tail position
def bestMin(low :(String,Double) = ("",Double.MaxValue)) :(String,Double) = {
  if (inputIt.hasNext) {
    val name = inputIt.next()
    val res = expensiveFunc(name)
    if (res < earlyAbort) (name, res)
    else if (res < low._2) bestMin((name,res))
    else bestMin(low)
  } else low
}
bestMin()  //res0: (String, Double) = (b,0.2)
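The claim that the foldLeft version calls expensiveFunc only as often as needed, yet still walks the whole iterator, can be checked with counters. This is a minimal sketch; the FoldCheck object, the best wrapper and the calls/visited counters are hypothetical additions around the fold shown above:

```scala
object FoldCheck {
  val pseudoResult = Map("a" -> 0.6, "b" -> 0.2, "c" -> 1.0)
  var calls = 0   // hypothetical counter: expensiveFunc invocations
  var visited = 0 // hypothetical counter: elements pulled from the iterator by foldLeft

  def expensiveFunc(s: String): Double = {
    calls += 1
    pseudoResult(s)
  }

  def best(inputs: Seq[String], earlyAbort: Double): (String, Double) =
    inputs.iterator.foldLeft(("", Double.MaxValue)) { case (low, name) =>
      visited += 1
      if (low._2 < earlyAbort) low // threshold already met: skip expensiveFunc
      else Seq(low, (name, expensiveFunc(name))).minBy(_._2)
    }
}
```

For best(Seq("a", "b", "c"), 0.5), visited ends at 3 while calls stays at 2: the fold still steps past c, it just skips the expensive call, which is why the tail-recursive bestMin is preferable for long inputs.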
jwvh

Fold over the input list and return from doIt as soon as a result is below the threshold. Try the following:

  val pseudoResult = Map("a" -> 0.6, "b" -> 0.2, "c" -> 1.0)

  def expensiveFunc(s: String): Double = {
    println(s"executed for ${s}")
    pseudoResult(s)
  }

  val inputsToTry = Seq("a", "b", "c")
  val earlyAbort = 0.5 // threshold

  def doIt(): List[(String, Double)] = {

    inputsToTry.foldLeft(List[(String, Double)]()) {
      case (n, name) =>
        val res = expensiveFunc(name)
        if (res < earlyAbort) {
          // non-local return: leaves doIt() at once, so the remaining inputs are never evaluated
          return n ++ List((name, res))
        }
        n ++ List((name, res))
    }

  }

  val (name, bestResult) = doIt().minBy(_._2)
  println(name)
  println(bestResult)

The output:

executed for a
executed for b
b
0.2

As you can see, only a and b are evaluated, and not c.

proximator
  • No, this does not work, because I cannot guarantee that there exists a `res < earlyAbort`; in that case I need the smallest `res`, so I still need external state with the results of all computations – Raphael Roth Aug 09 '18 at 13:50
  • You want to break the first time you find a result < earlyAbort, so you will have at most one result – proximator Aug 09 '18 at 13:54
  • Yes, but maybe I have no result < earlyAbort, so `find` will return None; in that case I need to take the element with the minimal value. – Raphael Roth Aug 09 '18 at 14:23
  • Ok, I see your issue now. I've edited the code. Can you try again please? – proximator Aug 09 '18 at 15:23

If you implement takeUntil and use it, you'd still have to go through the list once more to get the lowest one if you don't find what you are looking for. Probably a better approach would be to have a function that combines find with reduceOption, returning early if something is found or else returning the result of reducing the collection to a single item (in your case, finding the smallest one).

The result is comparable with what you could achieve using a Stream, as highlighted in a previous reply, but avoids relying on memoization, which keeps every evaluated element in memory and can be costly for very large collections.

A possible implementation could be the following:

import scala.annotation.tailrec

def findOrElse[A](it: Iterator[A])(predicate: A => Boolean,
                                   orElse: (A, A) => A): Option[A] = {
  @tailrec
  def loop(elseValue: Option[A]): Option[A] = {
    if (!it.hasNext) elseValue
    else {
      val next = it.next()
      if (predicate(next)) Some(next)
      else loop(Option(elseValue.fold(next)(orElse(_, next))))
    }
  }
  loop(None)
}

Let's add our inputs to test this:

def f1(in: String): Double = {
  println("calling f1")
  Map("a" -> 0.6, "b" -> 0.2, "c" -> 1.0, "d" -> 0.8)(in)
}

def f2(in: String): Double = {
  println("calling f2")
  Map("a" -> 0.7, "b" -> 0.6, "c" -> 1.0, "d" -> 0.8)(in)
}

val inputs = Seq("a", "b", "c", "d")

As well as our helpers:

def apply[IN, OUT](in: IN, f: IN => OUT): (IN, OUT) =
  in -> f(in)

def threshold[A](a: (A, Double)): Boolean =
  a._2 < 0.5

def compare[A](a: (A, Double), b: (A, Double)): (A, Double) =
  if (a._2 < b._2) a else b

We can now run this and see how it goes:

val r1 = findOrElse(inputs.iterator.map(apply(_, f1)))(threshold, compare)
val r2 = findOrElse(inputs.iterator.map(apply(_, f2)))(threshold, compare)
val r3 = findOrElse(Map.empty[String, Double].iterator)(threshold, compare)

r1 is Some((b,0.2)), r2 is Some((b,0.6)) and r3 is (reasonably) None. In the first case, since we use a lazy iterator and terminate early, we only invoke f1 twice; in the second case no value is below the threshold, so f2 is invoked for all four inputs.

You can have a look at the results and can play with this code here on Scastie.
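The takeUntil mentioned at the start of this answer is, as far as I know, not part of the standard Iterator API. A hypothetical version (an inclusive takeWhile, which is what the question was looking for) could look like the sketch below. As noted above, it needs a second pass over the buffered results to pick the minimum, although that pass does not call the expensive function again. TakeUntilOps, takeUntil, best and the calls counter are all illustrative names:

```scala
object TakeUntilSketch {
  // Hypothetical helper, not in the standard library: like takeWhile, but it also
  // emits the first element that satisfies `stop`, then ends.
  implicit class TakeUntilOps[A](it: Iterator[A]) {
    def takeUntil(stop: A => Boolean): Iterator[A] = new Iterator[A] {
      private var done = false
      def hasNext: Boolean = !done && it.hasNext
      def next(): A = {
        val a = it.next()
        if (stop(a)) done = true
        a
      }
    }
  }

  val pseudoResult = Map("a" -> 0.6, "b" -> 0.2, "c" -> 1.0)
  var calls = 0 // hypothetical counter: expensiveFunc invocations

  def expensiveFunc(s: String): Double = {
    calls += 1
    pseudoResult(s)
  }

  def best(inputs: Seq[String], earlyAbort: Double): Option[(String, Double)] = {
    val evaluated = inputs.iterator
      .map(name => name -> expensiveFunc(name)) // evaluated only when pulled
      .takeUntil(_._2 < earlyAbort)             // includes the first match, then stops
      .toList
    // Second pass over the buffered pairs only: a match, if any, is the last element
    // and also the minimum, since every earlier result was >= earlyAbort.
    if (evaluated.isEmpty) None else Some(evaluated.minBy(_._2))
  }
}
```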

stefanobaghino