8

I would like to inspect all the observations that reached some node in an rpart decision tree. For example, in the following code:

fit <- rpart(Kyphosis ~ Age + Start, data = kyphosis)
fit

n= 81 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

 1) root 81 17 absent (0.79012346 0.20987654)  
   2) Start>=8.5 62  6 absent (0.90322581 0.09677419)  
     4) Start>=14.5 29  0 absent (1.00000000 0.00000000) *
     5) Start< 14.5 33  6 absent (0.81818182 0.18181818)  
      10) Age< 55 12  0 absent (1.00000000 0.00000000) *
      11) Age>=55 21  6 absent (0.71428571 0.28571429)  
        22) Age>=111 14  2 absent (0.85714286 0.14285714) *
        23) Age< 111 7  3 present (0.42857143 0.57142857) *
   3) Start< 8.5 19  8 present (0.42105263 0.57894737) *

I would like to see all the observations in node 5 (i.e., the 33 observations for which Start>=8.5 & Start< 14.5). Obviously I could extract them manually, but I would like to have a function, say "get_node_date", so that I could just run get_node_date(5) and get the relevant observations.

Any suggestions on how to go about this?

MichaelChirico
Tal Galili

6 Answers

5

There seems to be no built-in function that extracts the observations from a specific node. I would solve it as follows: first determine which rules lead to the node you are interested in, using path.rpart. Then apply those rules one after the other to extract the observations.

This approach as a function:

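get_node_date <- function(tree = fit, node = 5, data = kyphosis){
  # path.rpart() prints the rules leading to the node and returns them as a list
  rule <- path.rpart(tree, node)
  # Split each rule (except "root") into variable, operator and threshold
  rule_2 <- sapply(rule[[1]][-1], function(x) strsplit(x, '(?<=[><=])(?=[^><=])|(?<=[^><=])(?=[><=])', perl = TRUE))
  # Evaluate every rule on the data and keep the rows that satisfy all of them
  # (this handles numeric splits only, not categorical ones)
  ind <- apply(do.call(cbind, lapply(rule_2, function(x) eval(call(x[2], data[,x[1]], as.numeric(x[3]))))), 1, all)
  data[ind,]
  }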

For node 5 you get:

get_node_date()

 node number: 5 
   root
   Start>=8.5
   Start< 14.5
   Kyphosis Age Number Start
2    absent 158      3    14
10  present  59      6    12
11  present  82      5    14
14   absent   1      4    12
18   absent 175      5    13
20   absent  27      4     9
23  present  96      3    12
26   absent   9      5    13
28   absent 100      3    14
32   absent 125      2    11
33   absent 130      5    13
35   absent 140      5    11
37   absent   1      3     9
39   absent  20      6     9
40  present  91      5    12
42   absent  35      3    13
46  present 139      3    10
48   absent 131      5    13
50   absent 177      2    14
51   absent  68      5    10
57   absent   2      3    13
59   absent  51      7     9
60   absent 102      3    13
66   absent  17      4    10
68   absent 159      4    13
69   absent  18      4    11
71   absent 158      5    14
72   absent 127      4    12
74   absent 206      4    10
77  present 157      3    13
78   absent  26      7    13
79   absent 120      2    13
81   absent  36      4    13
DatamineR
2

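Terminal node assignments for training observations in rpart can be obtained from $where. Each value is the row number in fit$frame of the leaf the observation fell into, not the node number itself: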

fit <- rpart(Kyphosis ~ Age + Start, data = kyphosis)
fit$where

As a function:

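get_node <- function(rpart.object=fit, data=kyphosis, node.number=10) {
  # $where holds row indices into $frame, so map them to node numbers first
  node.of.obs <- as.integer(rownames(rpart.object$frame))[rpart.object$where]
  data[which(node.of.obs == node.number),]
}
get_node()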

This works for training observations only, not for new observations, and only for terminal nodes, not for inner nodes.
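For new observations, one workaround is to overwrite the yval column in a copy of the tree with the node numbers, so that predict() returns node numbers instead of fitted values. This is only a minimal sketch: it assumes that predict() with type = "vector" simply looks up frame$yval for the leaf each row lands in, and fit_nodes, newdata and new_nodes are illustrative names that are not part of the answer.

```r
library(rpart)

fit <- rpart(Kyphosis ~ Age + Start, data = kyphosis)

# Copy the tree and store each node's number where its fitted value used to be
# (assumption: predict(type = "vector") returns frame$yval of the reached leaf)
fit_nodes <- fit
fit_nodes$frame$yval <- as.integer(rownames(fit$frame))

# Stand-in for genuinely new observations: the first rows of the training data
newdata <- kyphosis[1:6, ]
new_nodes <- predict(fit_nodes, newdata, type = "vector")
new_nodes
```

Like $where, this still yields terminal nodes only; the modified copy should be used for this lookup and nothing else, since its fitted values are no longer meaningful.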

  • 1
    This works only for the terminal nodes of the tree. The question asks for non-terminal nodes. – riccardo-df Jan 24 '23 at 08:03
  • 1
    @riccardo-df True that! I've adjusted the answer. I've left the answer anyway, b/c some users might need only terminal nodes and this involves only a limited amount of code. Higher-ranking answers provide a more thorough answer, obviously. – Marjolein Fokkema Jan 24 '23 at 23:23
1

The partykit package also provides a canned solution for this. You just need to convert the rpart object to the party class in order to use its unified interface for dealing with trees. And then you can use the data_party() function.

Using the fit from the question and having loaded library("partykit") you can first coerce the rpart tree to party:

pfit <- as.party(fit)
plot(pfit)

[Figure: plot of the full pfit tree]

There are only two small nuisances for extracting the data in the way you want: (1) The model.frame() from the original fit is always dropped in the coercion and needs to be reattached manually. (2) A different numbering scheme is used for the nodes. You want node 4 (rather than 5) now.

pfit$data <- model.frame(fit)
data4 <- data_party(pfit, 4)
dim(data4)
## [1] 33  5
head(data4)
##    Kyphosis Age Start (fitted) (response)
## 2    absent 158    14        7     absent
## 10  present  59    12        8    present
## 11  present  82    14        8    present
## 14   absent   1    12        5     absent
## 18   absent 175    13        7     absent
## 20   absent  27     9        5     absent

Another route is to extract the subtree starting from node 4 and then take the data from that:

pfit4 <- pfit[4]
plot(pfit4)

[Figure: plot of the subtree of pfit from node 4]

Then data_party(pfit4) gives you the same as data4 above. And pfit4$data gives you the data without the (fitted) node and the predicted (response).
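To translate an rpart node number into the party numbering without reading it off the plot, one option is to look up the node's position among the rows of fit$frame. This is a hedged sketch: rpart_to_party_id is a hypothetical helper, and it assumes that as.party() numbers the nodes consecutively in the same order as the rows of fit$frame, which matches the 5 to 4 mapping above.

```r
library(rpart)

fit <- rpart(Kyphosis ~ Age + Start, data = kyphosis)

# Hypothetical helper: position of an rpart node number among the rows of
# tree$frame. Assumption: as.party() numbers nodes 1, 2, 3, ... in this order.
rpart_to_party_id <- function(tree, node) {
  match(node, as.integer(rownames(tree$frame)))
}

rpart_to_party_id(fit, 5)
```

The result can then be passed as the node id to data_party().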

Achim Zeileis
  • if you used `ptree$data <- model.frame(eval(tree$call$data))` the variables not used in the formula wouldn't be dropped – rawr Sep 26 '16 at 22:00
  • True...but only if `data` contains all variables in the `formula` which is not necessarily the case. With the `model.frame()` you also get transformed variables, e.g., `log()`, `Surv()` or `factor()` versions of variables that are often created on the fly. – Achim Zeileis Sep 26 '16 at 22:08
  • BTW: The `as.party()` coercion for `rpart` objects now _keeps the data_ by default! Thus, you can do `as.party(fit, data = TRUE)` (which is the new default) or `as.party(fit, data = FALSE)` (which corresponds to the old behavior). – Achim Zeileis Sep 26 '16 at 22:10
1

Yet another way: this works by finding all of the terminal nodes under any particular node and returning the matching subset of the data used in the call.

parent <- function(x) {
  # Path of node numbers from the root (1) down to x
  if (x[1] != 1)
    c(Recall(if (x %% 2 == 0L) x / 2 else (x - 1) / 2), x) else x
}

subset.rpart <- function(tree, node = 1L) {
  # Recover the data frame used in the original rpart() call
  data <- eval(tree$call$data, parent.frame(1L))
  # Root-to-node paths for every node in the tree
  wh <- sapply(as.integer(rownames(tree$frame)), parent)
  # Keep the paths passing through `node`, then the nodes at or below it
  wh <- unique(unlist(wh[sapply(wh, function(x) node %in% x)]))
  data[rownames(tree$frame)[tree$where] %in% wh[wh >= node], ]
}

fit <- rpart(Kyphosis ~ Age + Start, data = kyphosis)

head(subset.rpart(fit, 5))
#    Kyphosis Age Number Start
# 2    absent 158      3    14
# 10  present  59      6    12
# 11  present  82      5    14
# 14   absent   1      4    12
# 18   absent 175      5    13
# 20   absent  27      4     9
rawr
0

rpart() returns an rpart object (see ?rpart.object) whose frame element holds the number of observations in each node:

require(rpart)
fit2 <- rpart(Kyphosis ~ Age + Start, data = kyphosis)
fit2

get_node_date <-function(nodeId,fit)
{  
  fit$frame[toString(nodeId),"n"]
}


for (i in c(1,2,4,5,10,11,22,23,3) )
  cat(get_node_date(i,fit2),"\n")
Guy s
  • 1
    You don't get the observations through this, but only the number of observations which fall into a category – DatamineR Apr 20 '16 at 16:31
0

An alternative method consists of finding all child nodes of the given node n, which we can do using the rpart object. Combining this information with the end node of each point in the dataset (kyphosis, in this question), obtained from fit$where as explained by @rawr, you get all points in the dataset that pass through the given node, which need not be an end node.

In summary, the steps are:

  1. Find node numbers and identify those that are end nodes ("leaf"). This information can be found in the frame element of the rpart object.
  2. Compute all child nodes of the given node n. They can be computed recursively using the fact that the children of node n are 2*n and 2*n+1, as explained in the rpart.plot package vignette, page 26.
  3. Once the leaves hanging from the node n are known, pick the points in the dataset in those leaves using the where element of the rpart object
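The numbering rule in step 2 also runs in reverse: the parent of node n is n %/% 2. As a compact cross-check of the steps above, the following sketch walks upward from each observation's leaf instead of downward from the node. It is not the implementation given in this answer; is_under, leaf_of_obs and node5 are illustrative names.

```r
library(rpart)

fit <- rpart(Kyphosis ~ Age + Start, data = kyphosis)

# Hypothetical helper: TRUE if `node` lies on the path from the root to `leaf`
# (a node counts as being under itself)
is_under <- function(leaf, node) {
  while (leaf > node) leaf <- leaf %/% 2L  # step up to the parent
  leaf == node
}

# Node number of the leaf each training observation ended up in
leaf_of_obs <- as.integer(rownames(fit$frame))[fit$where]

# Observations whose leaf hangs from node 5
in_node5 <- vapply(leaf_of_obs, is_under, logical(1), node = 5L)
node5 <- kyphosis[in_node5, ]
nrow(node5)
```

The row count should agree with the n column for node 5 in the printed tree.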

I coded steps 1 and 2 in the function get_children_nodes() and step 3 in the function get_node_data(), which answers the question posed. In this function, I've included the option to print the corresponding node rule (rule = TRUE) to get the same output as @DatamineR.

library(rpart)
library(rpart.plot)

fit <- rpart(Kyphosis ~ Age + Start, data = kyphosis)
get_children_nodes <- function(tree, node){
  # check if node is a leaf based in rpart object (tree) information (step 1)
  z <- tree$frame
  is_leaf <- z$var == "<leaf>"
  nodes <- as.integer(row.names(z))
  
  # find recursively all children nodes (step 2)
  find_children <- function(node, nodes, is_leaf){
    condition <- is_leaf[nodes == node]
    if (condition) {
      # If node is leaf, return it
      v1 <- node
    } else {
      # If node is not leaf, search children leaf recursively
      v1 <- c(find_children(2 * node, nodes, is_leaf), 
              find_children(2 * node + 1, nodes, is_leaf))
    } 
    return(v1)
  }
  
  return(find_children(node, nodes, is_leaf))
}
get_node_data <- function(dataset, tree, node, rule = FALSE) {
  # Find children nodes of the node
  children_nodes <- get_children_nodes(tree, node)
  # match those nodes into the rpart node identification
  id_nodes <- which(as.integer(row.names(tree$frame)) %in% children_nodes)
  # Get the elements in the dataset involved in the node (step 3)
  filtered_dataset <- dataset[tree$where %in% id_nodes, ]
  
  # print the node rule if needed
  if(rule) {
    rpart::path.rpart(tree, node, pretty = TRUE)
    cat("  \n")
  }
  return( filtered_dataset)
}
# Get the children nodes
get_children_nodes(fit, 5)
#> [1] 10 22 23
# Complete function to return the elements of node 5
get_node_data(kyphosis, fit, 5, rule = TRUE) 
#> 
#>  node number: 5 
#>    root
#>    Start>=8.5
#>    Start< 14.5
#> 
#>    Kyphosis Age Number Start
#> 2    absent 158      3    14
#> 10  present  59      6    12
#> 11  present  82      5    14
#> 14   absent   1      4    12
#> 18   absent 175      5    13
#> 20   absent  27      4     9
#> 23  present  96      3    12
#> 26   absent   9      5    13
#> 28   absent 100      3    14
#> 32   absent 125      2    11
#> 33   absent 130      5    13
#> 35   absent 140      5    11
#> 37   absent   1      3     9
#> 39   absent  20      6     9
#> 40  present  91      5    12
#> 42   absent  35      3    13
#> 46  present 139      3    10
#> 48   absent 131      5    13
#> 50   absent 177      2    14
#> 51   absent  68      5    10
#> 57   absent   2      3    13
#> 59   absent  51      7     9
#> 60   absent 102      3    13
#> 66   absent  17      4    10
#> 68   absent 159      4    13
#> 69   absent  18      4    11
#> 71   absent 158      5    14
#> 72   absent 127      4    12
#> 74   absent 206      4    10
#> 77  present 157      3    13
#> 78   absent  26      7    13
#> 79   absent 120      2    13
#> 81   absent  36      4    13

Created on 2023-08-14 with reprex v2.0.2

josep maria porrà