
I am building a tree using the partykit R package, and I am wondering if there is a simple, efficient way to determine the depth of each internal node. For example, the root node would have depth 0, its two kid nodes would have depth 1, their kid nodes would have depth 2, and so forth. This will eventually be used to calculate the minimal depth of a variable. Below is a very basic example (taken from vignette("constparty", package = "partykit")):

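library("partykit")
library("rpart")

# Expand the Titanic contingency table into one row per passenger
data("Titanic", package = "datasets")
ttnc <- as.data.frame(Titanic)
ttnc <- ttnc[rep(1:nrow(ttnc), ttnc$Freq), 1:4]
names(ttnc)[2] <- "Gender"

# Fit an rpart tree and convert it to a partykit "party" object
rp <- rpart(Survived ~ ., data = ttnc)
ttncTree <- as.party(rp)
plot(ttncTree)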

# One of my many attempts, which does NOT work
internalNodes <- nodeids(ttncTree)[-nodeids(ttncTree, terminal = TRUE)]
depth(ttncTree) - unlist(nodeapply(ttncTree, ids = internalNodes, FUN = function(n) depth(n)))

In this example, I want to output something similar to:

nodeid = 1 2 4 7 
depth  = 0 1 2 1

I apologize if my question is too specific.

2 Answers


Here's a possible solution, which should be efficient enough because trees usually have no more than several dozen nodes. I'm ignoring node #1, as its depth is always 0, so (IMO) there is no point in calculating or showing it.

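# Inner nodes other than the root (node ids run 1..n, so negative indexing works)
Inters <- nodeids(ttncTree)[-nodeids(ttncTree, terminal = TRUE)][-1]
# Each inner node is counted once for itself and once for every non-root
# inner ancestor, which gives its depth
table(unlist(sapply(Inters, function(x) intersect(Inters, nodeids(ttncTree, from = x)))))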
# 2 4 7 
# 1 2 1 
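The same subtree-counting idea extends to every node, root and terminal nodes included: a node's depth is the number of inner nodes whose subtree strictly contains it. A minimal sketch under that reading; `allDepths()` is a hypothetical name, not a partykit function, and it is still quadratic in the number of nodes, so it does not help with the large-tree timing issue.

```r
library("partykit")

# Hypothetical helper (not part of partykit): depth of every node, root = 0.
# A node's depth is the number of inner nodes that have it as a strict
# descendant. Like the intersect() answer, this costs O(n^2) for n nodes.
allDepths <- function(tree) {
  ids <- nodeids(tree)
  inner <- setdiff(ids, nodeids(tree, terminal = TRUE))
  # Strict descendants of each inner node (nodeids(from = x) includes x itself)
  desc <- lapply(inner, function(x) setdiff(nodeids(tree, from = x), x))
  # Count how often each node id occurs; the root never occurs, so it gets 0
  counts <- table(factor(unlist(desc), levels = ids))
  setNames(as.integer(counts), ids)
}
```

For the Titanic tree, `allDepths(ttncTree)[c(1, 2, 4, 7)]` should reproduce the depths 0, 1, 2, 1 from the desired output in the question.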
David Arenburg
  • Thanks for the solution! It works perfectly for the example provided. Unfortunately, I am growing an unpruned tree with around 5000 nodes from a large dataset (eventually building a random forest), so this approach takes too long (around 15 min for a single tree). I greatly appreciate the solution and will try to figure out a faster method. – Peter Calhoun Feb 08 '16 at 19:41
  • I'll wait for @AchimZeileis's response; I'd bet he can figure out something better. – David Arenburg Feb 09 '16 at 08:31
  • Sorry for coming in so late and thanks to @David for his solution. My first attempt would have been: `depth(ttncTree) - unlist(nodeapply(ttncTree, ids = nodeids(ttncTree), depth))`. One could also restrict the `ids` to the inner nodes only as @David did. But my guess is that this is inefficient for very large trees because you traverse the tree for each `depth` call in the apply. So possibly you have to cycle through `ttncTree$node` manually and record all depth values. I think we haven't got anything in `partykit` that does exactly what you are looking for out of the box _and_ efficiently. – Achim Zeileis Feb 12 '16 at 23:09
  • Thanks for the comments and solution. I am currently using David's solution, but will come back to this problem in a few weeks. I'm thinking a fast solution may parse the output of `as.character(ttncTree)[1]`. I'll repost if I come up with a faster solution. Thanks again! – Peter Calhoun Feb 14 '16 at 01:06
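Following the suggestion to cycle through `ttncTree$node` manually, here is a sketch that walks the tree once with a recursive function, passing the current depth down to the kids, so each node is visited exactly once instead of once per `depth()` call. It uses the partykit accessors `id_node()`, `kids_node()` and `is.terminal()`; `nodeDepths()` is a hypothetical name, and the code assumes node ids run 1..n, as `nodeids()` returns for trees converted with `as.party()`.

```r
library("partykit")

# Hypothetical helper (not part of partykit): depth of every node in a single
# recursive pass over the partynode structure stored in tree$node (root = 0).
# Assumes node ids are 1..n, so they can index the result vector directly.
nodeDepths <- function(tree) {
  ids <- nodeids(tree)
  depths <- integer(length(ids))
  walk <- function(node, d) {
    depths[id_node(node)] <<- d  # record this node's depth
    if (!is.terminal(node)) {
      # Kids are one level deeper than their parent
      for (kid in kids_node(node)) walk(kid, d + 1L)
    }
  }
  walk(tree$node, 0L)
  setNames(depths, ids)
}
```

`nodeDepths(ttncTree)[internalNodes]` then restricts the result to the inner nodes computed in the question. The recursion only goes as deep as the tree itself, not as deep as the number of nodes.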

I had to revisit this problem recently. Below is a function that determines the depth of each node. It counts the depth as the number of vertical bars (|) on each node's line in the output of print.party().

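library(partykit)
library(stringr)

# Depth of every node, read off the printed tree: print.party() indents each
# node line with one "|" per level. The code assumes the nodes are printed
# in id order.
idDepth <- function(tree) {
  outTree <- capture.output(tree)
  idCount <- 1
  depthValues <- rep(NA, length(tree))  # length() of a party = number of nodes
  names(depthValues) <- 1:length(tree)
  for (index in seq_along(outTree)) {
    # Only lines containing a node id such as "[3]" describe a node
    if (grepl("\\[[0-9]+\\]", outTree[index])) {
      depthValues[idCount] <- str_count(outTree[index], "\\|")
      idCount <- idCount + 1
    }
  }
  return(depthValues)
}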

> idDepth(ttncTree)
1 2 3 4 5 6 7 8 9 
0 1 2 2 3 3 1 2 2
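The same parsing idea can also be vectorized in base R (no stringr), reading the node id from each "[id]" tag instead of assuming the nodes are printed in id order. A sketch; `idDepth2()` is a hypothetical name, and like `idDepth()` it assumes no split label or factor level contains a "|".

```r
library("partykit")

# Hypothetical vectorized variant of idDepth(): parse the printed tree once,
# keep the lines that carry a node id such as "[3]", and count the "|"
# indentation marks on each of them. Assumes no label contains "|".
idDepth2 <- function(tree) {
  out <- capture.output(print(tree))
  nodeLines <- grep("\\[[0-9]+\\]", out, value = TRUE)
  # Node id: the digits inside the first "[...]" tag on each line
  tags <- regmatches(nodeLines, regexpr("\\[[0-9]+\\]", nodeLines))
  ids <- as.integer(gsub("[^0-9]", "", tags))
  # gregexpr() returns -1 when there is no match, so count positive positions
  bars <- gregexpr("|", nodeLines, fixed = TRUE)
  depths <- vapply(bars, function(m) sum(m > 0), integer(1))
  depths <- setNames(depths, ids)
  depths[order(ids)]
}
```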

There is probably a simpler, faster solution, but this one is already faster than the intersect() approach. Below is a timing comparison for a large tree (around 1,500 nodes):

# Compare computation time for a large tree
library(partykit)
library(rpart)
library(mlbench)
set.seed(470174)
dat <- data.frame(mlbench.friedman1(5000))
# Grow a large tree: cp = -1 turns off the complexity-parameter stopping rule
rp <- rpart(as.formula(paste0("y ~ ", paste(paste0("x.", 1:10), collapse = " + "))),
            data = dat, control = rpart.control(cp = -1, minsplit = 3, maxdepth = 10))
partyTree <- as.party(rp)

> length(partyTree) # Number of nodes
[1] 1503
> 
> # Intersect() computation time
> Inters <- nodeids(partyTree)[-nodeids(partyTree, terminal = TRUE)][-1]
> system.time(table(unlist(sapply(Inters, function(x) intersect(Inters, nodeids(partyTree, from = x))))))
   user  system elapsed 
  22.38    0.00   22.44 
> 
> # Proposed computation time
> system.time(idDepth(partyTree))
   user  system elapsed 
   2.38    0.00    2.38
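For the end goal in the question, the minimal depth of a variable, the node depths can be combined with the split variable of each inner node. A sketch, assuming `depths` is a vector of node depths named by node id (for example the output of `idDepth()` above) and that each split's `varid` is a column number in the tree's stored data `tree$data`; `minDepth()` is a hypothetical name, while `nodeapply()`, `split_node()` and `varid_split()` are partykit functions.

```r
library("partykit")

# Hypothetical helper (not part of partykit): minimal depth of each split
# variable, i.e. the smallest depth of an inner node that splits on it.
# `depths` must be a vector of node depths named by node id.
minDepth <- function(tree, depths) {
  inner <- setdiff(nodeids(tree), nodeids(tree, terminal = TRUE))
  # Column index (in tree$data) of the variable each inner node splits on
  varids <- unlist(nodeapply(tree, ids = inner,
                             FUN = function(n) varid_split(split_node(n))))
  vars <- names(tree$data)[varids]
  # Smallest depth per variable, returned as a plain named vector
  byVar <- tapply(depths[as.character(inner)], vars, min)
  byVar <- setNames(as.vector(byVar), names(byVar))
  byVar[order(byVar)]  # shallowest variable first
}
```

Usage would be `minDepth(partyTree, idDepth(partyTree))`; variables that never appear in a split are simply absent from the result.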