1

I'm trying to write a tail-recursive quicksort in Scala that works by building up a continuation, without the use of a trampoline. So far I have the following:

object QuickSort {

  def sort[A: Ordering](toSort: Seq[A]): Seq[A] = {
    val ordering = implicitly[Ordering[A]]
    import ordering._

    @scala.annotation.tailrec
    def step(list: Seq[A], conts: List[Seq[A] => Seq[A]]): Seq[A] = list match {
      case s if s.length <= 1 => conts.foldLeft(s) { case (acc, next) => next(acc) }
      case Seq(h, tail @ _*) => {
        val (less, greater) = tail.partition(_ < h)
        step(less, { sortedLess: Seq[A] =>
            /*
            Can't use 

            step(greater, sortedGreater => (sortedLess :+ h) ++ sortedGreater)

            and keep the tailrec annotation
           */
          (sortedLess :+ h) ++ sort(greater)
        } +: conts)
      }
    }

    step(toSort, Nil)
  }

}

Click for ScalaFiddle

On my computer, the above implementation works with a random sequence of at least 4000000 elements, but I have my doubts about it. Specifically, I would like to know:

  1. Is it stack-safe? Can we tell by just looking at the code? It compiles with @tailrec, but the call to sort(greater) seems a bit suspicious.
  2. If the answer to (1) is "No", is it possible to write a tail-recursive quicksort in CPS style in Scala, that is, without using a trampoline? If so, how?

To be clear, I've looked at this related question that talks about how to implement a tail recursive quick sort using trampolines (which I know how to use) or your own explicit stack, but I specifically want to know if and how it can be done in a different way.

lloydmeta

3 Answers

1
  1. Your code is tail-recursive, so it should be stack-safe. The call to sort(greater) is parked in the continuation, so it lives on the heap rather than the stack. Given a sufficiently large problem of the wrong shape, you might exhaust the heap, but that takes a lot more than it takes to blow the stack.
Iadams
  • Ah yes, I understand that I'm trading stack for heap and that, given a sufficiently large problem, this CPS trick (and the related trampoline/free monad variations) will blow my heap. Intuitively, it makes sense that the `sort` call is safe because it lives on the heap, but my understanding of exactly why is a bit fuzzy and hand-wavy. – lloydmeta Sep 19 '16 at 11:55
0

No, your code is not stack-safe. sort calls step, and step calls sort again for the greater partition, so each such nested call adds frames to the stack.

To do CPS, let's start from the normal form:

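// Assumes A and its Ordering operators (e.g. `<`) are in scope, as in the question's code.
def sort(list: Seq[A]): Seq[A] = list match {
  case s if s.length <= 1 => s
  case Seq(h, tail @ _*) => {
    val (less, greater) = tail.partition(_ < h)
    val l = sort(less)
    val g = sort(greater)
    (l :+ h) ++ g // append the pivot itself, not Seq(h)
  }
}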

Then translate it to CPS, which is very straightforward:

// Assumes A and its Ordering operators (e.g. `<`) are in scope, as in the question's code.
def sort(list: Seq[A], cont: Seq[A] => Unit): Unit = list match {
  case s if s.length <= 1 => cont(s)
  case Seq(h, tail @ _*) => {
    val (less, greater) = tail.partition(_ < h)
    sort(less, { l =>
      sort(greater, { g =>
        cont((l :+ h) ++ g) // append the pivot itself, not Seq(h)
      })
    })
  }
}

Note:

  • A CPS function always returns Unit.
  • A continuation always returns Unit.
  • Every recursive call becomes a call to the function itself, with the remaining statements wrapped in a continuation.
  • Every return becomes a call to the continuation.
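
As a smaller illustration of the four rules above, here is a minimal sketch that applies them to summing a list. The names `CpsSum`, `sum`, `sumCps` and `run` are made up for this example and do not come from the answer. Note that the continuations still nest when they are finally invoked, so this sketch is not meant for very large inputs.

```scala
object CpsSum {
  // Direct style: the addition happens after the recursive call returns.
  def sum(xs: List[Int]): Int = xs match {
    case Nil    => 0
    case h :: t => h + sum(t)
  }

  // CPS: returns Unit; the "rest of the work" (adding h) moves into the continuation.
  def sumCps(xs: List[Int], cont: Int => Unit): Unit = xs match {
    case Nil    => cont(0)                      // a return becomes a call to the continuation
    case h :: t => sumCps(t, r => cont(h + r))  // recursive call with the remaining work wrapped up
  }

  // Wrap the CPS version back into an ordinary function.
  def run(xs: List[Int]): Int = {
    var result = 0
    sumCps(xs, r => result = r)
    result
  }
}
```

Calling `CpsSum.run(List(1, 2, 3))` should give the same answer as `CpsSum.sum(List(1, 2, 3))`.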

Finally, wrap it back into normal form:

def quicksort(list: Seq[A]): Seq[A] = {
  var result: Seq[A] = Seq.empty // a var needs an explicit type and an initial value
  sort(list, { r => result = r })
  result
}

NOTE: The CPS transform turns every call into a tail call (NOT tail recursion). Since Scala doesn't perform general tail-call optimization, you need to do the tail-call optimization manually:

trait TCF[T] {
  def result: Option[T]
  def apply(): TCF[T]
}
private def tco[T](f: => TCF[T]): TCF[T] = new TCF[T] {
  def result = None
  def apply() = f
}

def quicksort[A: Ordering](list: Seq[A]): Seq[A] = {
  case class Result(r: Seq[A]) extends Exception
  Iterator.iterate(sort(list, { r: Seq[A] =>
    new TCF[Seq[A]] {
      def result = Some(r)
      def apply() = throw new RuntimeException("unreachable")
    }
  }))(c => c()).dropWhile(_.result == None).next().result.get
}

private def sort[A: Ordering](list: Seq[A], cont: Seq[A] => TCF[Seq[A]]): TCF[Seq[A]] = {
  val ordering = implicitly[Ordering[A]]
  import ordering._
  list match {
    case s if s.length <= 1 => tco(cont(s))
    case Seq(h, tail@_*) => {
      val (less, greater) = tail.partition(_ < h)
      tco(sort(less, { l: Seq[A] =>
        tco(sort(greater, { g: Seq[A] =>
          tco(cont((l :+ h) ++ g))
        }))
      }))
    }
  }
}

Try it here.
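
For comparison, the hand-rolled `TCF` machinery above plays a role similar to the standard library's `scala.util.control.TailCalls`. The following is only a sketch of that alternative, not the answer's method; the names `TailCallsSort` and `go` are made up here. It is still a trampoline, so it does not satisfy the question's "no trampoline" constraint.

```scala
import scala.util.control.TailCalls._

object TailCallsSort {
  def quicksort[A](list: Seq[A])(implicit ord: Ordering[A]): Seq[A] = {
    import ord._ // brings `<` into scope for A

    // Each recursive call is suspended with `tailcall`; `result` runs the trampoline loop.
    def go(xs: Seq[A]): TailRec[Seq[A]] =
      if (xs.length <= 1) done(xs)
      else {
        val h = xs.head
        val (less, greater) = xs.tail.partition(_ < h)
        for {
          l <- tailcall(go(less))
          g <- tailcall(go(greater))
        } yield (l :+ h) ++ g
      }

    go(list).result
  }
}
```

Usage would look like `TailCallsSort.quicksort(Vector(3, 1, 2))`.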

Zang MingJie
  • Thanks for your answer. I managed to get it to compile after a few fixes (see https://scalafiddle.io/sf/OzInX1U/2), but unfortunately, noticed that it is neither stack-safe (on my computer it dies when given a Seq with length > 3000; in ScalaJS, this could be browser-dependent), nor tail-recursive (sort not in tail position; same problem I faced in my solution). Also, if possible, I would like the solution to be fully immutable (references and data structures). – lloydmeta Sep 19 '16 at 11:52
  • @lloydmeta I'm not sure CPS ensures stack safety, but if you need stack safety, you can weave the continuation out, which uses a concept similar to a trampoline. – Zang MingJie Sep 20 '16 at 06:58
  • @lloydmeta The CPS transform turns every call into a tail call (NOT tail recursion). Since Scala doesn't perform general tail-call optimization, you need to do the tail-call optimization manually. – Zang MingJie Sep 20 '16 at 09:20
  • Thanks for updating your code! It can now handle a lot more elements than before. BTW, it looks like your original code stack-overflowed because it was doing a non-tail-recursive call when traversing the tree to build the continuation; your new code takes care of this by making the calls "lazy" via trampolining. You're right though, CPS by itself does not ensure stack-safety in Scala, because passing the last value into the continuation can cause a stack overflow if each layer simply calls the previous one (cont'd). – lloydmeta Sep 20 '16 at 11:55
  • (cont'd) I've updated my question to take care of this by accumulating the continuations and folding over them once we reach the end; see https://scalafiddle.io/sf/xh0CMpu/1 – lloydmeta Sep 20 '16 at 11:56
0

I decided to use JVisualVM to take a look at the call tree for the implementation I had in the question, and found that it was eating up stack as a result of the ++ step(greater) invocation. I think it was just very difficult to reach the point of a stack overflow there, because the list was being divided roughly in half each time, with the smaller half being sorted in a tail-recursive, stack-safe manner.
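
To make the stack growth visible without a profiler, here is a rough sketch that reuses the shape of the question's code and adds a nesting counter. The `DepthProbe` object, its `depth`/`maxDepth` fields and the `probe` helper are instrumentation invented for this example. On already-sorted input every element ends up in the greater partition, so the nesting of `sort` calls should grow linearly with the input size, whereas random input keeps it shallow.

```scala
object DepthProbe {
  // Hypothetical instrumentation: how deeply `sort` is currently nested, and the peak value.
  private var depth = 0
  var maxDepth = 0

  def sort[A](toSort: Seq[A])(implicit ord: Ordering[A]): Seq[A] = {
    import ord._
    depth += 1
    maxDepth = math.max(maxDepth, depth)

    @scala.annotation.tailrec
    def step(list: Seq[A], conts: List[Seq[A] => Seq[A]]): Seq[A] = list match {
      case s if s.length <= 1 => conts.foldLeft(s) { case (acc, next) => next(acc) }
      case Seq(h, tail @ _*) =>
        val (less, greater) = tail.partition(_ < h)
        // Same shape as the question: `sort(greater)` runs inside the continuation,
        // so it is invoked from within the enclosing `sort` call's foldLeft.
        step(less, ((sortedLess: Seq[A]) => (sortedLess :+ h) ++ sort(greater)) +: conts)
    }

    val result = step(toSort, Nil)
    depth -= 1
    result
  }

  // Resets the counters, sorts the input, and reports the peak nesting of `sort`.
  def probe(input: Seq[Int]): Int = {
    depth = 0
    maxDepth = 0
    sort(input)
    maxDepth
  }
}
```

For example, `DepthProbe.probe(Vector.range(1, 201))` reports a peak nesting equal to the input length, which is the worst case described above.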

After thinking about this a bit, I came up with the following revised solution (try it out here)

object QuickSort {

  def sort[A: Ordering](toSort: Seq[A]): Seq[A] = {
    val ordering = implicitly[Ordering[A]]
    import ordering._

    // Aliasing allows us to be tail-recursive
    def step2(list: Seq[A], conts: Vector[Seq[A] => Seq[A]]): Seq[A] = step(list, conts)

    @scala.annotation.tailrec
    def step(list: Seq[A], conts: Vector[Seq[A] => Seq[A]]): Seq[A] = list match {
      case s if s.length <= 1 => conts.foldLeft(s) { case (acc, next) => next(acc) }
      case Seq(h, tail @ _*) => {
        val (less, greater) = tail.partition(_ < h)
        val nextConts: Vector[Seq[A] => Seq[A]] =
          { sortedLess: Seq[A] =>
            sortedLess :+ h
          } +: { appendedLess: Seq[A] =>
            step2(greater, Vector({ sortedGreater => appendedLess ++ sortedGreater }))
          } +: conts
        step(less, nextConts)
      }
    }
    step(toSort, Vector.empty)
  }

}

Main differences are:

  • Using a step2 alias for step to keep the @tailrec annotation happy.
  • Instead of invoking step(greater) in the continuation for sorting the less partition, we just add another continuation to the conts accumulator, in which we append the sorted greater partition to the sorted less partition. I suppose you could argue that this accumulator is just a stack on the heap.
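
The "stack on the heap" idea can be seen in isolation with a toy example. This is only a sketch using `Int => Int` functions instead of `Seq[A] => Seq[A]`, and the names `ContStack`, `runAll`, `base` and `pushed` are made up: continuations are prepended with `+:` and then run front to back with `foldLeft`, so the most recently added one runs first.

```scala
object ContStack {
  // Runs the accumulated continuations in order, feeding each result to the next one.
  def runAll(seed: Int, conts: Vector[Int => Int]): Int =
    conts.foldLeft(seed) { case (acc, next) => next(acc) }

  // The "older" continuation, added first.
  val base: Vector[Int => Int] = Vector((x: Int) => x * 2)

  // Prepending puts the newer continuation at the front, so it runs first.
  val pushed: Vector[Int => Int] = ((x: Int) => x + 1) +: base
}
```

Here `ContStack.runAll(3, ContStack.pushed)` first adds 1 and then doubles, giving 8.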

Interestingly enough, this solution turns out to be quite fast, beating the Scalaz trampolining solution in the linked question. Comparing it with the half-stack solution above, it was around 30 ns slower when sorting 1 million elements, but that was within error.

[info] Benchmark                             (sortLength)  Mode  Cnt     Score    Error  Units
[info] SortBenchmarks.sort                            100  avgt   30     0.034 ±  0.001  ms/op
[info] SortBenchmarks.sort                          10000  avgt   30     6.258 ±  0.072  ms/op
[info] SortBenchmarks.sort                        1000000  avgt   30  1016.849 ± 23.572  ms/op
[info] SortBenchmarks.scalazSort                      100  avgt   30     0.070 ±  0.001  ms/op
[info] SortBenchmarks.scalazSort                    10000  avgt   30    10.426 ±  0.092  ms/op
[info] SortBenchmarks.scalazSort                  1000000  avgt   30  1635.693 ± 68.068  ms/op
lloydmeta