When fitting a tree with rpart, the returned object contains a "where" vector that tells which leaf of the tree each record in the training dataset falls into. Is there a function that returns something similar to this "where" vector for a test dataset?

Meng zhao
  • Not sure what you are asking. Do you want to get the subset of observations for a given node? https://stackoverflow.com/questions/36748531/getting-the-observations-in-a-rparts-node-i-e-cart – rawr Mar 09 '18 at 17:37
  • Just something similar to the rpart$where vector of numbers, which tells which leaf a record is on. But this is only available for the training data. I wonder if I can get something similar for a test dataset. The predict function doesn't seem to do this; it only gives predicted values. – Meng zhao Mar 09 '18 at 17:39

2 Answers

I think the partykit package does what you want:

library('rpart')
fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)
fit
rpart.plot::rpart.plot(fit)

Check with the same data:

set.seed(1)
idx <- sample(nrow(kyphosis), 5L)
fit$where[idx]
# 22 30 46 71 16 
#  9  3  7  7  3 

library('partykit')
fit <- as.party(fit)
predict(fit, kyphosis[idx, ], type = 'node')
# 22 30 46 71 16 
#  9  3  7  7  3 

Check with new data:

dd <- kyphosis[idx, ]
set.seed(1)
dd[] <- lapply(dd, sample)
predict(fit, dd, type = 'node')
# 22 30 46 71 16 
#  5  3  7  9  3 

## so #46 should meet criteria for the 7th leaf:
with(kyphosis[46, ],
     Start  >= 8.5  &  # node 1
       Start < 14.5 &  # node 2
       Age  >= 55   &  # node 4
       Age  >= 111     # node 6
)
# [1] TRUE
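
The same approach can be applied to a held-out test set. Below is a minimal sketch, assuming a simple deterministic split of kyphosis into training and test rows (the split and the names train, test, fit_rpart and fit_party are only illustrative, not part of the original answer). Note that type = 'node' is understood by predict only after the model has been converted with as.party; calling predict directly on the rpart object dispatches to predict.rpart, which accepts only "vector", "prob", "class" and "matrix".

```r
library(rpart)
library(partykit)

# Hypothetical split: every 4th row is held out as test data
test_idx <- seq(1, nrow(kyphosis), by = 4)
train <- kyphosis[-test_idx, ]
test  <- kyphosis[test_idx, ]

# Fit on the training rows only, then convert to a party object
fit_rpart <- rpart(Kyphosis ~ Age + Number + Start, data = train)
fit_party <- as.party(fit_rpart)

# Leaf ids for the training rows; these should line up with fit_rpart$where
train_nodes <- predict(fit_party, newdata = train, type = "node")
print(all(train_nodes == fit_rpart$where))

# Leaf ids for the unseen test rows, on the same scale as fit_rpart$where
test_nodes <- predict(fit_party, newdata = test, type = "node")
print(table(test_nodes))
```

Keeping the rpart fit and the party object under separate names avoids overwriting the original model, so fit_rpart$where stays available for comparison.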
rawr
  • Thanks for the reply. I tried, but it seems the only accepted values for 'type' are “vector”, “prob”, “class”, “matrix” – Meng zhao Mar 09 '18 at 19:38
As you mention, the predict.rpart function in the rpart package doesn't have a "where" option (to show the leaf node associated with each prediction). However, the rpart.predict function in the rpart.plot package will do this. For example:

> library(rpart.plot)
> fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)
> rpart.predict(fit, newdata=kyphosis[1:3,], nn=TRUE)

gives (note the node number nn column):

   absent present nn
1 0.42105 0.57895  3
2 0.85714 0.14286 22
3 0.42105 0.57895  3

And

> rpart.predict(fit, newdata=kyphosis[1:3,], nn=TRUE)$nn

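gives just the node numbers. Note that these are rpart node numbers (the row names of fit$frame), whereas fit$where stores row positions in fit$frame: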

[1]  3 22  3
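
If values on the same scale as fit$where are needed, the node numbers can be translated by looking them up in the row names of fit$frame. This is a minimal sketch of that lookup; the variable names nn and where_like are only illustrative.

```r
library(rpart)
library(rpart.plot)

fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)

# rpart node numbers for each row (here the training data, so we can compare)
nn <- rpart.predict(fit, newdata = kyphosis, nn = TRUE)$nn

# Translate node numbers into row positions of fit$frame,
# which is the scale used by fit$where
where_like <- match(as.character(nn), rownames(fit$frame))

# On the training data this should agree with fit$where
print(all(where_like == fit$where))
```

For a test dataset, pass it as newdata instead of kyphosis; the match() step stays the same.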

To show the rule for each prediction use

> rpart.predict(fit, newdata=kyphosis[1:3,], rules=TRUE)

which gives

   absent present
1 0.42105 0.57895 because Start <  9
2 0.85714 0.14286 because Start is 9 to 15 & Age >= 111
3 0.42105 0.57895 because Start <  9
Stephen Milborrow