
Suppose I build a toy tree model with rpart. How can I get the depth of the tree?

library(rpart)
library(partykit)
# Grow a deliberately overfitted tree (cp = 0, minsplit = 1)
fit <- rpart(factor(am) ~ ., mtcars, control = rpart.control(cp = 0, minsplit = 1))
plot(as.party(fit))

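I know how to count the leaves. For a binary tree the number of leaves says something about the depth, but it is not the depth itself: a binary tree with L leaves has depth at least ceiling(log2(L)), and two trees with the same number of leaves can have different depths.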

sum(fit$frame$var=="<leaf>")
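To see why the leaf count alone cannot give the depth, here is a small sketch that works directly on rpart-style node numbers (root = 1, children of node n are 2n and 2n + 1). The helpers node_depth() and n_leaves() and the two node vectors are hypothetical, made up for illustration: both trees have four leaves, but their depths differ.

```r
# Hypothetical helper: depth of a node from its rpart-style number
# (root = 1, node n has children 2n and 2n + 1)
node_depth <- function(nodes) floor(log2(nodes))

# Hypothetical helper: rpart splits are binary, so a node is a leaf
# exactly when its left child 2n is absent
n_leaves <- function(nodes) sum(!((2 * nodes) %in% nodes))

balanced <- 1:7                          # full tree of depth 2
chain    <- c(1, 2, 3, 6, 7, 14, 15)     # splits keep going down the right side

n_leaves(balanced)              # 4
n_leaves(chain)                 # 4
max(node_depth(balanced))       # 2
max(node_depth(chain))          # 3
ceiling(log2(n_leaves(chain)))  # 2, only a lower bound on the depth
```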
hxd1011

2 Answers


rpart has an unexported function, tree.depth, that returns the depth of each node in the vector of node numbers passed to it. Using the data from the question:

nodes <- as.numeric(rownames(fit$frame))
max(rpart:::tree.depth(nodes))
## [1] 2
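A minimal sketch extending this to a per-node table, assuming the unexported rpart:::tree.depth() keeps its current behavior (unexported functions can change between versions without notice). Because every internal node has children one level deeper, the maximum depth over the leaves equals the maximum over all nodes.

```r
library(rpart)

# Same toy model as in the question
fit <- rpart(factor(am) ~ ., data = mtcars,
             control = rpart.control(cp = 0, minsplit = 1))

# rpart stores node numbers as the row names of fit$frame
nodes  <- as.numeric(rownames(fit$frame))
depths <- rpart:::tree.depth(nodes)  # unexported: may change between rpart versions

# One row per node: number, split variable ("<leaf>" for terminal nodes), depth
per_node <- data.frame(node  = nodes,
                       var   = as.character(fit$frame$var),
                       depth = depths)
print(per_node)

# Every internal node has children one level deeper,
# so the deepest node is always a leaf
is_leaf <- fit$frame$var == "<leaf>"
max(depths[is_leaf]) == max(depths)  # TRUE
```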
G. Grothendieck
  • Thanks, this answer is perfect. But how did you know about the unexported function? Did you check the source code? – hxd1011 Dec 01 '16 at 14:25
  • Yes. One has to check the source. https://github.com/cran/rpart/blob/4a009f14f2b342baa2df55854d578a45b16a17da/R/zzz.R – G. Grothendieck Dec 01 '16 at 14:48
  • If you checked the source, could you see whether you can answer this question? http://stats.stackexchange.com/questions/248706/why-i-cannot-achieve-100-accuracy-in-my-simple-training-data-with-cart-model – hxd1011 Dec 01 '16 at 14:50
  • @hxd1011, Hi, could you please help me with my question? https://stackoverflow.com/questions/49228786/rpart-find-number-of-leaves-that-a-cp-value-to-pruning-a-tree-would-return Thank you!! – user1412 Mar 12 '18 at 06:42
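One way to find unexported helpers such as tree.depth without browsing the GitHub source is to compare everything defined in the package namespace with what the package exports. A sketch using base R's asNamespace(), ls() and getNamespaceExports(); the grep pattern "depth" is just an example search term.

```r
# Everything defined in the rpart namespace vs. what it exports
ns         <- asNamespace("rpart")
all_objs   <- ls(envir = ns, all.names = TRUE)
exported   <- getNamespaceExports(ns)
unexported <- setdiff(all_objs, exported)

# Search the unexported names for anything depth-related
grep("depth", unexported, value = TRUE)

# Print the source of a candidate (same as typing rpart:::tree.depth)
get("tree.depth", envir = ns)
```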

Given the way the children of node n are numbered (2 * n and 2 * n + 1; see page 26 of the rpart vignette), the tree depth can be obtained by truncating the base-2 logarithm of the largest node number:

library(rpart)

fit <- rpart(factor(am)~., mtcars, control=rpart.control(cp = 0, minsplit = 1))
plot(fit, margin = 0.1)
text(fit, digits = 3, all = TRUE, use.n = TRUE, cex = 0.8, pretty = TRUE)

row.names(fit$frame) |> 
  as.integer() |> 
  max() |> 
  log(base = 2) |> 
  trunc()
#> [1] 2
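Why the largest node number is enough: with children numbered 2n and 2n + 1, the nodes at depth d are numbered 2^d through 2^(d+1) - 1, so every node at depth d + 1 has a larger number than any node at depth d. The sketch below cross-checks floor(log2(n)) against walking parent links (the parent of node n is n %/% 2) for all nodes of a full tree of depth 5; the helper depth_by_parent() is hypothetical.

```r
# Hypothetical helper: depth found by walking parent links up to the root,
# using the fact that node n's parent is n %/% 2
depth_by_parent <- function(n) {
  d <- 0L
  while (n > 1L) {
    n <- n %/% 2L
    d <- d + 1L
  }
  d
}

nodes <- 1:63                      # every node of a full tree of depth 5
walk  <- vapply(nodes, depth_by_parent, integer(1))
logd  <- as.integer(floor(log2(nodes)))

identical(walk, logd)              # TRUE: both methods agree
max(logd)                          # 5
```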

In fact, rpart:::tree.depth() uses the same idea to compute the depth of each node:

rpart:::tree.depth
#> function (nodes) 
#> {
#>     depth <- floor(log(nodes, base = 2) + 1e-07)
#>     depth - min(depth)
#> }
#> <bytecode: 0x000002b48f733c48>
#> <environment: namespace:rpart>

Created on 2023-08-14 with reprex v2.0.2
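Note that the depth - min(depth) step makes the result relative to the shallowest node in the vector you pass in. A local copy of the printed function (renamed tree_depth here so the sketch does not depend on rpart internals) shows this: passing all node numbers, including the root 1, gives absolute depths, while passing only a subtree or only the leaves shifts the results. The 1e-07 is presumably a guard against floating-point round-off in log().

```r
# Local copy of the logic printed above (rpart:::tree.depth), for illustration
tree_depth <- function(nodes) {
  depth <- floor(log(nodes, base = 2) + 1e-07)  # small epsilon guards against round-off
  depth - min(depth)                            # relative to the shallowest node passed in
}

tree_depth(c(1, 2, 3, 6, 7))    # 0 1 1 2 2  (whole tree, absolute depths)
tree_depth(c(3, 6, 7, 14, 15))  # 0 1 1 2 2  (subtree rooted at node 3)
tree_depth(c(2, 6, 7))          # 0 1 1      (leaves only; absolute depths are 1 2 2)
```

So to get the absolute depth of the leaves, compute the depths on all node numbers first and subset afterwards, rather than passing only the leaf numbers.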

josep maria porrà