
I have a question about writing recursive algorithms in a functional style. I will use Scala for my example here, but the question applies to any functional language.

I am doing a depth-first enumeration of an n-ary tree where each node has a label and a variable number of children. Here is a simple implementation that returns the labels of the leaf nodes.

case class Node[T](label:T, ns:Node[T]*)
def dfs[T](r:Node[T]):Seq[T] = {
    if (r.ns.isEmpty) Seq(r.label) else for (n<-r.ns;c<-dfs(n)) yield c
}
val r = Node('a, Node('b, Node('d), Node('e, Node('f))), Node('c))
dfs(r) // returns Seq[Symbol] = ArrayBuffer('d, 'f, 'c)

Now say that sometimes I want to be able to give up on traversing oversize trees by throwing an exception. Is this possible in a functional language? Specifically, is this possible without using mutable state? That seems to depend on what you mean by "oversize". Here is a purely functional version of the algorithm that throws an exception when it reaches a node at depth 3 or greater (counting the root as depth 0).

def dfs[T](r:Node[T], d:Int = 0):Seq[T] = {
    require(d < 3)
    if (r.ns.isEmpty) Seq(r.label) else for (n<-r.ns;c<-dfs(n, d+1)) yield c
}

But what if a tree is oversized because it is too broad rather than too deep? Specifically, what if I want to throw an exception the n-th time the dfs() function is called recursively, regardless of how deep the recursion goes? The only way I can see to do this is to keep a mutable counter that is incremented with each call; I can't see how to do it without a mutable variable.
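A counter does not have to be mutable if it is threaded through the recursion: each call receives the remaining budget as an argument and returns what is left alongside its result, so each child sees how much of the budget its earlier siblings used up. The following is a minimal sketch of that state-passing pattern, assuming the `Node` class above; `dfsBudget` and its `budget` parameter are hypothetical names, and `require` supplies the exception as in the depth-limited version.

```scala
case class Node[T](label: T, ns: Node[T]*)

// Hypothetical helper: returns the leaf labels together with the unused budget.
// Every call spends one unit of `budget`; the leftover is threaded through the
// children from left to right, so no mutable counter is needed.
def dfsBudget[T](r: Node[T], budget: Int): (Seq[T], Int) = {
  require(budget > 0, "tree has too many nodes") // give up on oversize trees
  val remaining = budget - 1
  if (r.ns.isEmpty) (Seq(r.label), remaining)
  else
    r.ns.foldLeft((Seq.empty[T], remaining)) { case ((acc, left), child) =>
      val (labels, leftAfterChild) = dfsBudget(child, left)
      (acc ++ labels, leftAfterChild)
    }
}

val r = Node("a", Node("b", Node("d"), Node("e", Node("f"))), Node("c"))
// dfsBudget(r, 6) yields the leaves d, f, c with 0 budget left (the tree has 6 nodes).
// dfsBudget(r, 5) throws IllegalArgumentException when it reaches the sixth node.
```

The exception fires on the n-th call no matter whether the tree is deep or broad, because the budget counts calls rather than levels.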

I'm new to functional programming and have been working under the assumption that anything you can do with mutable state can be done without, but I don't see the answer here. The only thing I can think to do is write a version of dfs() that returns a view over all the nodes in the tree in depth-first order.

def dfs[T](r:Node[T]):TraversableView[T, Traversable[_]] = ...

Then I could impose my limit by saying dfs(r).take(n), but I don't see how to write this function. In Python I'd just write a generator that yields nodes as it visits them, but I don't see how to achieve the same effect in Scala. (Scala's equivalent to a Python-style yield statement appears to be a visitor function passed in as a parameter, but I can't figure out how to write one of these that will generate a sequence view.)
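One possible analogue of a Python generator in Scala is a lazy `Iterator`: `++` takes its right-hand side by name and `flatMap` expands a child only when the consumer asks for more elements, so `take(n)` stops the traversal after `n` nodes. This is a sketch under the assumption that a single-pass, non-memoizing `Iterator[Node[T]]` is acceptable in place of a `TraversableView`; `dfsIterator` is a hypothetical name.

```scala
case class Node[T](label: T, ns: Node[T]*)

// Pre-order depth-first traversal as a lazy Iterator. `++` evaluates its
// argument only once the left side is exhausted, and `flatMap` builds a child's
// iterator only when more nodes are requested, much like a Python generator.
def dfsIterator[T](r: Node[T]): Iterator[Node[T]] =
  Iterator.single(r) ++ r.ns.iterator.flatMap(child => dfsIterator(child))

val r = Node("a", Node("b", Node("d"), Node("e", Node("f"))), Node("c"))

// Stops pulling from the traversal after the first three nodes.
val firstThree = dfsIterator(r).take(3).map(_.label).toList          // List(a, b, d)
// Leaf labels only, as in the original dfs.
val leaves = dfsIterator(r).filter(_.ns.isEmpty).map(_.label).toList // List(d, f, c)
```

Because nothing is memoized, a caller that needs only a prefix pays only for that prefix; the trade-off is that each iterator can be consumed once.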

EDIT: Getting close to the answer.

Here is a function that returns a Stream of nodes in depth-first order.

def dfs[T](r: Node[T]): Stream[Node[T]] = {
    (r #:: Stream.empty /: r.ns)(_ ++ dfs(_))
}

That is almost it. The only problem is that Stream memoizes all results, which is a waste of memory. I want a traversable view. The following is the idea, but does not compile.

def dfs[T](r: Node[T]): TraversableView[Node[T], Traversable[Node[T]]] = {
    (Traversable(r).view /: r.ns)(_ ++ dfs(_))
}

It gives a "found TraversableView[Node[T], Traversable[Node[T]]], required TraversableView[Node[T], Traversable[_]]" error for the ++ operator. If I change the return type to TraversableView[Node[T], Traversable[_]], I get the same problem with the "found" and "required" clauses switched. So there's some magic type-variance incantation I haven't lit upon yet, but this is close.

W.P. McNeill

3 Answers


It can be done: you just have to write some code to actually iterate through the children in the way you want (as opposed to relying on a for comprehension).

More explicitly, you'll have to write code to iterate through a list of children and check if the "depth" crossed your threshold. Here's some Haskell code (I'm really sorry, I'm not fluent in Scala, but this can probably be easily transliterated):

http://ideone.com/O5gvhM

In this code, I've basically replaced the for loop with an explicit recursive version. This allows me to stop the recursion once the number of visited nodes is already too large (i.e., limit is not positive). When I recurse to examine the next child, I subtract the number of nodes visited by the dfs of the previous child and pass the result as the limit for the next child.
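The Haskell behind the link is not reproduced on this page, so as a rough illustration only, here is a hypothetical Scala sketch of the approach just described in its "visit all nodes" form, where the number of nodes a child visited is simply the length of the list it returns. It assumes the question's `Node` class; `dfsAll` and `visitChildren` are illustrative names, not the code at the link.

```scala
case class Node[T](label: T, ns: Node[T]*)

// Labels of at most `limit` nodes in pre-order. The number of nodes visited is
// the length of the returned list, so no separate counter is carried around.
def dfsAll[T](node: Node[T], limit: Int): List[T] =
  if (limit <= 0) Nil
  else node.label :: visitChildren(node.ns.toList, limit - 1)

// Explicit recursion over the children instead of a for comprehension:
// each child gets whatever budget its earlier siblings left over.
def visitChildren[T](children: List[Node[T]], limit: Int): List[T] =
  children match {
    case first :: rest if limit > 0 =>
      val firstLabels = dfsAll(first, limit)
      firstLabels ++ visitChildren(rest, limit - firstLabels.length)
    case _ => Nil // no children left, or no budget left
  }

val r = Node("a", Node("b", Node("d"), Node("e", Node("f"))), Node("c"))
// dfsAll(r, 3) == List("a", "b", "d")
```

Recomputing `length` at each step is linear in the size of the child's result; that is fine for a sketch, and the leaf-only transliteration below avoids it by returning the count explicitly.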

Functional languages are fun, but they're a huge leap from imperative programming. It really makes you pay attention to the concept of state, because all of it is excruciatingly explicit in the arguments when you go functional.

EDIT: Explaining this a bit more.

I ended up converting from "print just the leaf nodes" (which was the original algorithm from the OP) to "print all nodes". This enabled me to have access to the number of nodes the subcall visited through the length of the resulting list. If you want to stick to the leaf nodes, you'll have to carry around how many nodes you have already visited:

http://ideone.com/cIQrna

EDIT again: To clean up this answer, I've put all the Haskell code on ideone and transliterated it to Scala, so this can stay here as the definitive answer to the question:

case class Node[T](label:T, children:Seq[Node[T]])

case class TraversalResult[T](num_visited:Int, labels:Seq[T])

def dfs[T](node:Node[T], limit:Int):TraversalResult[T] =
    limit match {
        case 0     => TraversalResult(0, Nil)
        case limit => 
            node.children match {
                // Seq() and +: match any Seq, not only a List
                case Seq() => TraversalResult(1, List(node.label))
                case children => {
                    val result = traverse(node.children, limit - 1)
                    TraversalResult(result.num_visited + 1, result.labels)
                }
            }
    }

def traverse[T](children:Seq[Node[T]], limit:Int):TraversalResult[T] =
    limit match {
        case 0     => TraversalResult(0, Nil)
        case limit =>
            children match {
                case Seq() => TraversalResult(0, Nil)
                case first +: rest => {
                    val trav_first = dfs(first, limit)
                    val trav_rest = 
                        traverse(rest, limit - trav_first.num_visited)
                    TraversalResult(
                        trav_first.num_visited + trav_rest.num_visited,
                        trav_first.labels ++ trav_rest.labels
                    )
                }
            }
    }

val n = Node(0, List(
    Node(1, List(Node(2, Nil), Node(3, Nil))),
    Node(4, List(Node(5, List(Node(6, Nil))))),
    Node(7, Nil)
))
for (i <- 1 to 8)
    println(dfs(n, i))

Output:

TraversalResult(1,List())
TraversalResult(2,List())
TraversalResult(3,List(2))
TraversalResult(4,List(2, 3))
TraversalResult(5,List(2, 3))
TraversalResult(6,List(2, 3))
TraversalResult(7,List(2, 3, 6))
TraversalResult(8,List(2, 3, 6, 7))

P.S. This is my first attempt at Scala, so the above probably contains some horrid non-idiomatic code. I'm sorry.

Cesar Kawakami
  • Haskell is a good choice for the solution, since it doesn't give you the option of using mutable state. I think I should be able to figure out how to implement the Scala version of your `traverse` function, but Scala has the extra difficulty of ensuring that the sequence returned by `dfs` is evaluated lazily, since, unlike Haskell, Scala uses strict evaluation by default. (The print-all-the-nodes strategy is what I was getting at in the edit I made to my original question in which I discuss Python-style generators.) – W.P. McNeill Nov 21 '12 at 01:09
  • I'm not sure I understand. I might be wrong, but the algorithm I posted doesn't rely on laziness for it to work. I've tested it locally by forcing strictness (adding `!`s all around the code) and it still worked. Can you elaborate on which specific point of the code is giving problems? Perhaps I can help you. – Cesar Kawakami Nov 21 '12 at 01:19
  • I'm sure your algorithm works. The issue is not that it relies on laziness, but that Haskell gives you laziness for free whereas in Scala you have to explicitly code for it. There's no point in an early exit if you have to enumerate the whole tree before doing it, so I'd have to do extra work to create a Scala version of what you wrote. – W.P. McNeill Nov 21 '12 at 01:25
  • The `var`s should be `val`s; apart from that, it looks OK. – Ivan Meredith Nov 21 '12 at 04:29

You can convert breadth into depth by passing along an index or taking the tail:

// Recursive methods need an explicit result type (: Int) to compile.
def suml(xs: List[Int], total: Int = 0): Int = xs match {
  case Nil => total
  case x :: rest => suml(rest, total+x)
}

def suma(xs: Array[Int], from: Int = 0, total: Int = 0): Int = {
  if (from >= xs.length) total
  else suma(xs, from+1, total + xs(from))
}

In the latter case, you already have an index you can use to limit your breadth if you want; in the former, just add a width parameter or some such.
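One way to apply this to the tree is to walk each node's child list with an explicit index, as `suma` does, and give up once the index reaches a width limit. This is a hypothetical sketch (`dfsWidth`, `visitFrom` and `maxWidth` are illustrative names, and the question's `Node` class is assumed); note that it limits the number of children per node, which is a different notion of "too broad" from capping the total number of calls.

```scala
case class Node[T](label: T, ns: Node[T]*)

// Leaf labels in depth-first order, throwing if any node has more than
// `maxWidth` children. The child list is walked with an index (`from`) and an
// accumulator in the style of `suma`, so the index doubles as a width counter.
def dfsWidth[T](r: Node[T], maxWidth: Int): Seq[T] = {
  @annotation.tailrec
  def visitFrom(from: Int, acc: Seq[T]): Seq[T] =
    if (from >= r.ns.length) acc
    else if (from >= maxWidth)
      throw new IllegalArgumentException(s"node ${r.label} has more than $maxWidth children")
    else visitFrom(from + 1, acc ++ dfsWidth(r.ns(from), maxWidth))

  if (r.ns.isEmpty) Seq(r.label) else visitFrom(0, Seq.empty)
}

val r = Node("a", Node("b", Node("d"), Node("e", Node("f"))), Node("c"))
// dfsWidth(r, 2) returns the leaves d, f, c; dfsWidth(r, 1) throws at node b.
```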

Rex Kerr
  • Aren't these equivalent to `List(1,2,3).sum` and `Array(1,2,3).sum`? I don't see how to incorporate them into my recursive algorithm. – W.P. McNeill Nov 21 '12 at 00:32
  • @W.P.McNeill - You know how to limit depth-first searching but not breadth-first. I showed you how to get a parameter that specifies how broadly you've searched. What is missing? – Rex Kerr Nov 21 '12 at 07:52
  • I'm not sure how to call `suml` and/or `suma` from `dfs`. But never mind. With your help about type casting the `++` operator in the related question, I think I've got an appropriately Scalaesque technique for lazy depth-first search. – W.P. McNeill Nov 21 '12 at 19:49

The following implements a lazy depth-first search over nodes in a tree.

import collection.TraversableView
case class Node[T](label: T, ns: Node[T]*)
def dfs[T](r: Node[T]): TraversableView[Node[T], Traversable[Node[T]]] =
  (Traversable[Node[T]](r).view /: r.ns) {
    (a, b) => (a ++ dfs(b)).asInstanceOf[TraversableView[Node[T], Traversable[Node[T]]]]
  }

This prints the labels of all the nodes in depth-first order.

val r = Node('a, Node('b, Node('d), Node('e, Node('f))), Node('c))
dfs(r).map(_.label).force
// returns Traversable[Symbol] = List('a, 'b, 'd, 'e, 'f, 'c)

This does the same thing, quitting after 3 nodes have been visited.

dfs(r).take(3).map(_.label).force
// returns Traversable[Symbol] = List('a, 'b, 'd)

If you want only the leaf nodes, you can use filter, and so forth.

Note that the fold clause of the dfs function requires an explicit asInstanceOf cast. See "Type variance error in Scala when doing a foldLeft over Traversable views" for a discussion of the Scala typing issues that necessitate this.

W.P. McNeill