
I have built a regression tree with rpart to model the walking of elderly people based on a few variables. I would like to use the output of the plot for further analysis in another software package. Is it possible to derive from the leaf nodes not only the walking per group, but also the standard deviation (of walking) within each leaf?

[Image: my regression tree]

library(rpart)
library(rpart.plot)

#### Decision tree with rpart
modelRT <- rpart(logwalkin ~ . - walkinmin - walkingtime, data = trainDF,
                 control = rpart.control(minsplit = 25, maxdepth = 8, cp = 0.00005))
rpart.plot(modelRT, type = 3, digits = 3, fallen.leaves = TRUE)
G5W
Joy

2 Answers


I don't think you can do it from the plot, but you can certainly derive the standard deviation at each leaf node from the rpart model. Since you do not provide your data, I will use the built-in iris data as an example. Since you are interested in regression, I will drop the class variable (Species) and predict Sepal.Length from the other variables.

Setup

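library(rpart)
library(rpart.plot)
library(partykit)  # provides as.party(), used here only for the node-numbered plot

RP = rpart(Sepal.Length ~ ., data=iris[,-5])
plot(as.party(RP))  # rpart.plot() only accepts rpart objects, so plot the party object directly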

[Image: Rpart tree]

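As you can see, nodes 4, 5, 6, 10, 11, 12 and 13 are the leaf nodes. The component RP$where of the fitted model records, for each original observation, the leaf it ended up in (as a row number of RP$frame; those rows are stored in the same depth-first order used to number the nodes in this plot). So you just need to aggregate using this variable.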

SD = aggregate(iris$Sepal.Length, list(RP$where), sd)
SD
  Group.1         x
1       4 0.2390221
2       5 0.2888391
3       6 0.2500526
4      10 0.4039577
5      11 0.3802046
6      12 0.3020486
7      13 0.2279132

Group.1 identifies the leaf node, and x gives the standard deviation of the points that ended up in that leaf. If you want to add the standard deviations to your plot, you can do that with mtext. After some fiddling with the placement:

rpart.plot(RP)
# The at= positions were tuned by hand for this tree's 7 leaves; adjust them for other trees
mtext(text=round(SD$x,1), side=1, line=3.2, at=seq(0.06,1,0.1505))

[Image: plot with standard deviations]
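To see why RP$where works as a grouping variable: according to the rpart documentation, each entry is the row number in RP$frame of the leaf that observation fell into. A small sanity check, assuming the same RP fit as above, is that every such row is a leaf and that the number of observations per where value equals the n column of that row.

```r
library(rpart)

# Same model as above: regress Sepal.Length on the other numeric columns
RP <- rpart(Sepal.Length ~ ., data = iris[, -5])

# RP$where holds, for each training row, the row index in RP$frame of its leaf
counts <- table(RP$where)
leaf_rows <- as.integer(names(counts))

# Every referenced row should be a leaf, and the counts should match frame$n
print(all(RP$frame$var[leaf_rows] == "<leaf>"))
print(all(as.vector(counts) == RP$frame$n[leaf_rows]))

# rpart's own node numbers for these leaves (what rpart.plot(nn = TRUE) displays)
print(rownames(RP$frame)[leaf_rows])
```

Note that the row numbers in RP$where are not the same as rpart's own node numbers (the row names of RP$frame), which is why the last line converts between them.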

G5W

To plot the standard deviation at each node of the tree, you can use rpart.plot with a node.fun, as described in Chapter 6 of the rpart.plot package vignette. For example:

library(rpart.plot)
data(iris)
tree = rpart(Sepal.Length~., data=iris, cp=.05) # example tree

# Calculate the standard deviation at each node of the tree
# (for an anova tree, frame$dev is the sum of squares about the node mean).
sd <- sqrt(tree$frame$dev / (tree$frame$n-1))

# Append the standard deviation as an extra column to the tree frame.
tree$frame <- cbind(tree$frame, sd)

# Create a node.fun to print the standard deviation at each node.
# See Chapter 6 of the rpart.plot vignette http://www.milbo.org/doc/prp.pdf.
node.fun.sd <- function(x, labs, digits, varlen)
{
    s <- round(x$frame$sd, 2) # round sd to 2 digits
    paste(labs, "\n\nsd", s)
}

# Plot the tree, using the node.fun to add the standard deviation to each node
rpart.plot(tree, type=4, node.fun=node.fun.sd)

which gives

[Image: tree with the standard deviation printed at each node]
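The formula above relies on the fact that, for a regression (method = "anova") tree without case weights, frame$dev at a node is the sum of squared deviations of the response from the node mean, so sqrt(dev/(n-1)) is the usual sample standard deviation. A quick check on the leaves, comparing it with the tree$where-based calculation from the other answer (same iris tree as above):

```r
library(rpart)

tree <- rpart(Sepal.Length ~ ., data = iris, cp = .05)

# SD from the frame: node sum of squares divided by n - 1
sd_frame <- sqrt(tree$frame$dev / (tree$frame$n - 1))

# SD computed directly from the training responses that land in each leaf
sd_direct <- tapply(iris$Sepal.Length, tree$where, sd)
leaf_rows <- as.integer(names(sd_direct))

# The two should agree up to floating-point error
print(all.equal(as.vector(sd_direct), sd_frame[leaf_rows]))
```

If the tree was fitted with weights, dev is a weighted sum of squares, so this simple formula would no longer give the ordinary unweighted SD.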

If you want the standard deviations at just the leaf nodes (not the interior nodes), you can do this:

library(rpart.plot)
data(iris)
tree = rpart(Sepal.Length~., data=iris, cp=.05)
sd <- sqrt(tree$frame$dev / (tree$frame$n-1))
is.leaf <- tree$frame$var == "<leaf>" # logical vec, indexed on row in frame
sd[!is.leaf] <- NA # change sd of non-leaf nodes to NA
tree$frame <- cbind(tree$frame, sd)
node.fun2 <- function(x, labs, digits, varlen)
{
    s <- paste("\n\nsd", round(x$frame$sd, 2)) # round sd to 2 digits
    s[is.na(x$frame$sd)] <- "" # delete NAs
    paste(labs, s)
}
rpart.plot(tree, type=4, node.fun=node.fun2)

which gives

[Image: tree with the standard deviation printed at the leaf nodes only]
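The question also mentions using the leaf values in another program. A minimal sketch, reusing the frame-based formula above, collects the node number, size, mean and standard deviation of each leaf into a data frame and writes it to CSV; the file name leaf_table.csv is only a placeholder. In the question's model the response is logwalkin, which appears to be log-transformed, so the means and SDs there would be on that log scale.

```r
library(rpart)

tree <- rpart(Sepal.Length ~ ., data = iris, cp = .05)
fr <- tree$frame

# Keep only the leaf rows of the frame
is.leaf <- fr$var == "<leaf>"

# One row per leaf: rpart node number, size, mean response (yval) and SD
leaf_table <- data.frame(
  node = as.integer(rownames(fr)[is.leaf]),
  n    = fr$n[is.leaf],
  mean = fr$yval[is.leaf],
  sd   = sqrt(fr$dev[is.leaf] / (fr$n[is.leaf] - 1))
)
print(leaf_table)

# Plain CSV that most statistics packages and spreadsheets can read
write.csv(leaf_table, "leaf_table.csv", row.names = FALSE)
```

The node column matches the node numbers that rpart.plot shows with nn = TRUE, so the exported rows can be matched back to the plot.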

Stephen Milborrow