
I have a tree structure that I'm receiving from a Java library. I am trying to flatten it, since I'm interested only in the "key" values of the tree. The tree is made up of zero or more instances of the following class:

class R(val key: String, val nodes: java.util.List[R]) {}

with an empty nodes list representing the end of a branch. A sample can be built via this code:

import scala.collection.JavaConverters._  // provides .asJava / .asScala

val sample = List[R](
  new R("1", List[R](
    new R("2", List[R]().asJava),
    new R("3", List[R](new R("4", List[R]().asJava)).asJava)
  ).asJava)
).asJava

I am having trouble writing a method that is both correct and efficient. This is what I have so far:

def flattenTree(tree: List[R]): List[String] = {
  tree.foldLeft(List[String]())((acc, x) => 
             x.key :: flattenTree(x.nodes.asScala.toList))
}

However, when I run this code, inefficient as it may be, the output is still incorrect. My result ends up being:

>>> flattenTree(sample.asScala.toList)
res0: List[String] = List(1, 3, 4)

which means for some reason I lost the node with key "2".

Can someone recommend a correct and more efficient way of flattening this tree?

Chris Martin
Will I Am

4 Answers


Your fold function never uses acc, so the keys accumulated from earlier siblings are dropped on each successive call. Try the following:

def flattenTree(tree: List[R]): List[String] = {
  tree.foldLeft(List[String]())((acc, x) =>
             x.key :: flattenTree(x.nodes.asScala.toList) ++ acc)
}

which generates the result: List(1, 3, 4, 2),

or, if proper ordering is important:

def flattenTree(tree: List[R]): List[String] = {
  tree.foldLeft(List[String]())((acc, x) =>
             acc ++ (x.key :: flattenTree(x.nodes.asScala.toList)))
}

which generates the result: List(1, 2, 3, 4)
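
Since the question also asks about efficiency: the `acc ++ ...` version copies the accumulator at every step. A minimal sketch of an alternative, assuming pre-order output is wanted, replaces the recursion with an explicit stack and a `ListBuffer`. The name `flattenTreeIter` is hypothetical, and `R` is repeated only to keep the snippet self-contained.

```scala
import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer

class R(val key: String, val nodes: java.util.List[R])

// Pre-order traversal driven by an explicit stack instead of recursion,
// so deep or wide trees do not grow the call stack.
def flattenTreeIter(tree: List[R]): List[String] = {
  val out = ListBuffer.empty[String]
  var stack: List[R] = tree  // the head is the next node to visit
  while (stack.nonEmpty) {
    val node = stack.head
    out += node.key
    // Put the children in front of the remaining siblings so they are visited first.
    stack = node.nodes.asScala.toList ::: stack.tail
  }
  out.toList
}
```

Each node is pushed and popped once, so the work is proportional to the number of nodes, and the result for the sample should match the ordered version above.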

Shadowlands
  • Thanks, that gets me over the hump. I can continue, and hopefully someone later will come up with a suggestion for a better way of doing this. – Will I Am Sep 06 '15 at 08:26

You can define a function to flatten an R object with flatMap:

// required to be able to use flatMap on java.util.List
import scala.collection.JavaConversions._

def flatten(r: R): Seq[String] = {
  r.key +: r.nodes.flatMap(flatten)
}

And a function to flatten a sequence of those:

def flattenSeq(l: Seq[R]): Seq[String] = l flatMap flatten

r.nodes.flatMap(flatten) is a Buffer, so prepending to it is not efficient: the overall complexity becomes quadratic. So, if the order is not important, it is more efficient to append:

def flatten(r: R): Seq[String] = r.nodes.flatMap(flatten) :+ r.key
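
If the order does matter, one way to avoid prepending to a Buffer while still getting pre-order output is to thread a single mutable accumulator through the recursion. This is only a sketch, not the code from the answer above: the names `flattenInto` and `flattenAll` are hypothetical, and it assumes the explicit `asScala` conversion from `JavaConverters`.

```scala
import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer

class R(val key: String, val nodes: java.util.List[R])

// Appends r's key, then the keys of all its descendants, to the shared buffer.
def flattenInto(r: R, out: ListBuffer[String]): Unit = {
  out += r.key                                  // constant-time append
  r.nodes.asScala.foreach(flattenInto(_, out))  // recursion depth equals tree depth
}

// Flattens a sequence of roots in pre-order.
def flattenAll(roots: Seq[R]): List[String] = {
  val out = ListBuffer.empty[String]
  roots.foreach(flattenInto(_, out))
  out.toList
}
```

The recursion is still as deep as the tree, so very deep trees could overflow the stack; wide trees are not a problem for this variant.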

Kolmar
  • I am not sure, maybe I'm missing something, but the children are java.util.List which do not support flatMap. I'd have to do a conversion asJava again? So flatten's body would become "r.key +: r.nodes.asScala.toSeq.flatMap(flatten3)" ? – Will I Am Sep 06 '15 at 17:01
  • @WillIAm Oh, sorry, I forgot to include the import to my answer. – Kolmar Sep 06 '15 at 17:28
  • Thanks! This JavaConversions._ vs JavaConverters._ is very confusing. :) I had the latter. – Will I Am Sep 06 '15 at 18:02
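
To illustrate the difference discussed in these comments: `JavaConversions._` converts implicitly, which is why `flatMap` appears to work directly on a `java.util.List`, whereas `JavaConverters._` only adds explicit `.asScala` / `.asJava` methods. `JavaConversions` was deprecated in Scala 2.12 and removed in 2.13, so a version with the explicit conversion may be the safer choice. The sketch below assumes that route; `flattenExplicit` is a hypothetical name.

```scala
import scala.collection.JavaConverters._  // explicit .asScala / .asJava only

class R(val key: String, val nodes: java.util.List[R])

// Same idea as the flatMap-based flatten, but the Java list is converted explicitly.
def flattenExplicit(r: R): List[String] =
  r.key :: r.nodes.asScala.toList.flatMap(flattenExplicit)
```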

Convert each R to a Scalaz Tree, and call flatten to do a pre-order traversal.

import scala.collection.JavaConversions._
import scalaz._

def rTree(r: R): Tree[String] =
  Tree.node(r.key, r.nodes.toStream.map(rTree))

sample.flatMap(r => rTree(r).flatten): Seq[String]
// List(1, 2, 3, 4)

Edit: Unfortunately, due to a bug in scalaz as of version 7.1.1, this causes a stack overflow for wide trees.

Chris Martin
  • Hmm, I tried your suggestion and increased my nodes to 10000 items, but I ended up with a stack overflow (somewhere between 1000 and 10000). The only change from the code above was to change List[String] to Seq[String] to make the compiler happy. I was going to go with this approach since it seemed a bit faster than the one suggested above by Kolmar. http://pastebin.com/BbCEm4H4 – Will I Am Sep 06 '15 at 18:33
  • The stack overflow seems to be a bug in scalaz :( – Chris Martin Sep 07 '15 at 03:24

What about using Streams like scalaz does:

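// needed to call foldLeft directly on java.util.List
import scala.collection.JavaConversions._

def flatten(rootElem: R): Stream[String] = {
  // xs holds the keys collected so far; each child is consed in front of it,
  // so later siblings come out first (the sample's root yields 1, 3, 4, 2)
  def flatten0(elem: R, xs: Stream[String]): Stream[String] =
    Stream.cons(elem.key, elem.nodes.foldLeft(xs)((acc, x) => flatten0(x, acc)))

  flatten0(rootElem, Stream.empty)
}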
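
If pre-order output is needed from a lazy version, a possible variant is sketched below. It is not the code from this answer: `flattenLazy` is a hypothetical name, it uses the explicit `asScala` conversion, and it keeps `Stream` to match the thread even though `Stream` is deprecated in favor of `LazyList` from Scala 2.13 on.

```scala
import scala.collection.JavaConverters._

class R(val key: String, val nodes: java.util.List[R])

// The tail of #:: is evaluated lazily, so children are only visited
// when the resulting stream is actually consumed.
def flattenLazy(r: R): Stream[String] =
  r.key #:: r.nodes.asScala.toStream.flatMap(flattenLazy)
```

Calling `flattenLazy(root).toList` forces the whole traversal; taking only the first few elements should leave the rest of the tree unvisited.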
Alexandr Nikitin