An alternative method consists of finding all the child nodes of the given node `n`.
We can use the rpart
object to find those. Combining this information with the
end node for each point in the dataset (kyphosis, in this question),
obtained from fit$where
as explained by
@rawar, you can get all the points in the dataset that belong to a given node, not
necessarily an end one.
In summary, the steps are:
- Find node numbers and identify those that are end nodes ("leaf"). This
information can be found in the
frame
element of the rpart object.
- Compute all children nodes of the given node
n
. They can be computed
recursively using the fact that the children of node n
are 2*n
and
2*n+1
, as explained in the `rpart.plot` package vignette (page 26).
- Once the leaves hanging from the node
n
are known, pick the points in the
dataset in those leaves using the where
element of the rpart object
I coded steps 1 and 2 in function get_children_nodes()
and step 3 in function
get_node_data()
that answers the question posed. In this function, I've
included the option to print the corresponding node rule (`rule = TRUE`) to
get the same answer as @datamineR:
# rpart supplies rpart() and path.rpart(), both used below
library(rpart)
# rpart.plot is not called in the code shown here; presumably it is loaded
# because the node-numbering rule used below comes from its vignette
library(rpart.plot)
# Fit a tree predicting Kyphosis from Age and Start on the kyphosis data
fit <- rpart(Kyphosis ~ Age + Start, data = kyphosis)
#' Find the leaf nodes hanging from a node of an rpart tree
#'
#' Despite its name, this returns only the terminal ("leaf") descendants of
#' `node`; intermediate split nodes are not included. If `node` is itself a
#' leaf, it is returned unchanged.
#'
#' @param tree A fitted tree object whose `frame` element has the node numbers
#'   as row names and a `var` column equal to "<leaf>" for terminal nodes.
#' @param node A single node number that is present in `tree$frame`.
#' @return A vector with the node numbers of the leaves below `node`, ordered
#'   left to right.
get_children_nodes <- function(tree, node) {
  # Node numbers and leaf flags come from the frame of the tree (step 1)
  frame <- tree$frame
  nodes <- as.integer(row.names(frame))
  is_leaf <- frame$var == "<leaf>"

  if (length(node) != 1L || is.na(node)) {
    stop("`node` must be a single, non-missing node number", call. = FALSE)
  }

  # Walk down the tree (step 2): the children of node n are numbered 2 * n and
  # 2 * n + 1, so no explicit parent/child table is needed
  find_children <- function(current) {
    idx <- match(current, nodes)
    if (is.na(idx)) {
      # Without this check an unknown node fails later with the cryptic
      # "argument is of length zero"
      stop(
        sprintf("node %s is not a node of the tree", format(current)),
        call. = FALSE
      )
    }
    if (is_leaf[idx]) {
      return(current)
    }
    c(find_children(2 * current), find_children(2 * current + 1))
  }

  find_children(node)
}
#' Extract the observations that fall under a node of an rpart tree
#'
#' @param dataset The data the tree was fitted on; its rows must line up
#'   one-to-one with `tree$where`.
#' @param tree A fitted rpart object (its `frame` and `where` elements are
#'   used).
#' @param node Number of the node of interest; it does not need to be a leaf.
#' @param rule If `TRUE`, print the path from the root to `node` with
#'   `rpart::path.rpart()` before returning.
#' @return The rows of `dataset` that end up in any leaf below `node`, always
#'   with the same number of columns as `dataset`.
get_node_data <- function(dataset, tree, node, rule = FALSE) {
  # A shorter `where` would be silently recycled by the logical subsetting
  # below and return the wrong rows, so fail loudly instead
  if (length(tree$where) != NROW(dataset)) {
    stop(
      "`dataset` must have one row per element of `tree$where` ",
      "(was the tree fitted on this exact data?)",
      call. = FALSE
    )
  }

  # Leaves hanging from the node
  leaf_nodes <- get_children_nodes(tree, node)

  # `tree$where` is compared against row positions in `tree$frame`, not node
  # numbers, so translate the leaf node numbers into frame row positions
  frame_rows <- which(as.integer(row.names(tree$frame)) %in% leaf_nodes)

  # Observations in the dataset that end in one of those leaves (step 3)
  in_node <- tree$where %in% frame_rows

  # Print the node rule if requested
  if (rule) {
    rpart::path.rpart(tree, node, pretty = TRUE)
    cat(" \n")
  }

  # drop = FALSE keeps a one-column dataset from collapsing to a vector
  dataset[in_node, , drop = FALSE]
}
# Leaf nodes hanging from node 5 (node 11 is a split, so its children 22 and
# 23 are listed instead)
get_children_nodes(fit, 5)
#> [1] 10 22 23
# Rows of kyphosis that fall in node 5; rule = TRUE also prints the path from
# the root to that node
get_node_data(kyphosis, fit, 5, rule = TRUE)
#>
#> node number: 5
#> root
#> Start>=8.5
#> Start< 14.5
#>
#> Kyphosis Age Number Start
#> 2 absent 158 3 14
#> 10 present 59 6 12
#> 11 present 82 5 14
#> 14 absent 1 4 12
#> 18 absent 175 5 13
#> 20 absent 27 4 9
#> 23 present 96 3 12
#> 26 absent 9 5 13
#> 28 absent 100 3 14
#> 32 absent 125 2 11
#> 33 absent 130 5 13
#> 35 absent 140 5 11
#> 37 absent 1 3 9
#> 39 absent 20 6 9
#> 40 present 91 5 12
#> 42 absent 35 3 13
#> 46 present 139 3 10
#> 48 absent 131 5 13
#> 50 absent 177 2 14
#> 51 absent 68 5 10
#> 57 absent 2 3 13
#> 59 absent 51 7 9
#> 60 absent 102 3 13
#> 66 absent 17 4 10
#> 68 absent 159 4 13
#> 69 absent 18 4 11
#> 71 absent 158 5 14
#> 72 absent 127 4 12
#> 74 absent 206 4 10
#> 77 present 157 3 13
#> 78 absent 26 7 13
#> 79 absent 120 2 13
#> 81 absent 36 4 13
Created on 2023-08-14 with reprex v2.0.2