When fitting a tree with rpart, the returned object contains a "where" vector that tells which leaf of the tree each record in the training dataset falls in. Is there a function that returns something similar to this "where" vector for a test dataset?
-
Not sure what you are asking. Do you want to get the subset of observations for any node? https://stackoverflow.com/questions/36748531/getting-the-observations-in-a-rparts-node-i-e-cart – rawr Mar 09 '18 at 17:37
-
Just something similar to the rpart$where vector of numbers, which tells which leaf a record is on. But this is only for the training data. I wonder if I can get something similar for a test dataset. The predict function doesn't seem to do this; it only gives predicted values. – Meng zhao Mar 09 '18 at 17:39
2 Answers
I think the partykit package does what you want:
library('rpart')
fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)
fit
rpart.plot::rpart.plot(fit)
Check with the same data:
set.seed(1)
idx <- sample(nrow(kyphosis), 5L)
fit$where[idx]
# 22 30 46 71 16
# 9 3 7 7 3
library('partykit')
fit <- as.party(fit)
predict(fit, kyphosis[idx, ], type = 'node')
# 22 30 46 71 16
# 9 3 7 7 3
Check with new data:
dd <- kyphosis[idx, ]
set.seed(1)
dd[] <- lapply(dd, sample)
predict(fit, dd, type = 'node')
# 22 30 46 71 16
# 5 3 7 9 3
## so #46 should meet criteria for the 7th leaf:
with(kyphosis[46, ],
Start >= 8.5 & # node 1
Start < 14.5 & # node 2
Age >= 55 & # node 4
Age >= 111 # node 6
)
# [1] TRUE
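
The steps above can be wrapped into one small function. This is only a sketch: the name leaf_for_newdata is made up for illustration and is not part of rpart or partykit. Note that type = 'node' is handled by the predict method for party objects, so the conversion with as.party has to happen first; calling predict on the original rpart object accepts only "vector", "prob", "class" and "matrix". The sketch also assumes the predictors have no missing values, as is the case for kyphosis.

```r
library(rpart)
library(partykit)

fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)

# Hypothetical helper (not part of either package): returns, for each row of
# newdata, the id of the terminal node that the row falls into.
leaf_for_newdata <- function(rpart_fit, newdata) {
  stopifnot(inherits(rpart_fit, "rpart"))
  party_fit <- partykit::as.party(rpart_fit)        # convert rpart -> party
  predict(party_fit, newdata = newdata, type = "node")
}

# On the training data the result can be compared with fit$where
train_nodes <- leaf_for_newdata(fit, kyphosis)
same_as_where <- all(unname(train_nodes) == unname(fit$where))

# Any data frame with the same predictor columns can be passed instead
test_nodes <- leaf_for_newdata(fit, kyphosis[1:10, ])
```

Comparing train_nodes with fit$where on the training data is a cheap sanity check before the helper is used on a real test set.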

rawr
-
Thanks for the reply. I tried, but it seems the only accepted values for 'type' are “vector”, “prob”, “class”, “matrix” – Meng zhao Mar 09 '18 at 19:38
As you mention, the function predict.rpart in the rpart package doesn't have a where option (to show the leaf node number associated with a prediction).
However, the rpart.predict function in the rpart.plot package will do this. For example:
> library(rpart.plot)
> fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)
> rpart.predict(fit, newdata=kyphosis[1:3,], nn=TRUE)
gives (note the node number nn column):
absent present nn
1 0.42105 0.57895 3
2 0.85714 0.14286 22
3 0.42105 0.57895 3
And
> rpart.predict(fit, newdata=kyphosis[1:3,], nn=TRUE)$nn
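gives just the node numbers, the counterpart of where for new data. Note that these are rpart node labels (the row names of fit$frame), whereas fit$where holds row indices into fit$frame: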
[1] 3 22 3
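
If the numbers need to be on the same scale as fit$where, the node labels can be matched against the row names of fit$frame. The following is a minimal sketch of that conversion; the variable names nn_new and where_new are only illustrative.

```r
library(rpart)
library(rpart.plot)

fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)
newdata <- kyphosis[1:3, ]

# Node labels (rpart's node numbering) for the new rows
nn_new <- rpart.predict(fit, newdata = newdata, nn = TRUE)$nn

# Convert node labels to row indices of fit$frame, the scale used by fit$where
where_new <- match(as.character(nn_new), rownames(fit$frame))

# The opposite direction: node labels for the training rows
nn_train <- as.integer(rownames(fit$frame))[fit$where]
```

Since newdata here is taken from the training set, where_new can be checked against fit$where[1:3].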
To show the rule for each prediction use
> rpart.predict(fit, newdata=kyphosis[1:3,], rules=TRUE)
which gives
absent present
1 0.42105 0.57895 because Start < 9
2 0.85714 0.14286 because Start is 9 to 15 & Age >= 111
3 0.42105 0.57895 because Start < 9

Stephen Milborrow