
Say I have

library(rpart)  # provides rpart() and the kyphosis data
head(kyphosis)
inTrain <- sample(1:nrow(kyphosis), 45, replace = FALSE)
TRAIN_KYPHOSIS <- kyphosis[inTrain,]
TEST_KYPHOSIS <- kyphosis[-inTrain,]

(kyph_tree <- rpart(Number ~ ., data = TRAIN_KYPHOSIS))

How do I get the terminal node from the fitted object for each observation in TEST_KYPHOSIS?

How do I get a summary, such as the deviance and the predicted value from the terminal node which each test observation maps to?

goldisfine

2 Answers


rpart actually has this functionality but it's not exposed (strangely enough, it's a rather obvious requirement).

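predict_nodes <-
    function (object, newdata, na.action = na.pass) {
        # Row index into object$frame of the terminal node for each observation
        where <-
            if (missing(newdata)) 
                object$where  # training data: already stored in the fit
            else {
                if (is.null(attr(newdata, "terms"))) {
                    # Build the model frame from the predictors only
                    Terms <- delete.response(object$terms)
                    newdata <- model.frame(Terms, newdata, na.action = na.action, 
                                           xlev = attr(object, "xlevels"))
                    if (!is.null(cl <- attr(Terms, "dataClasses"))) 
                        .checkMFClasses(cl, newdata, TRUE)
                }
                # Internal rpart routine that sends each row down the tree
                rpart:::pred.rpart(object, rpart:::rpart.matrix(newdata))
            }
        # Translate frame row indices into the node labels used in print/plot
        as.integer(row.names(object$frame))[where]
    }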

And then:

> predict_nodes(kyph_tree, TEST_KYPHOSIS)
 [1] 5 3 4 3 3 5 5 3 3 3 3 5 5 4 3 5 4 3 3 3 3 4 3 4 4 5 5 3 4 4 3 5 3 5 5 5
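
For the second part of the question (a per-observation summary), the node assignments can be looked up in the fitted object's `frame`, which stores `n`, `dev` and `yval` for every node. The following is a minimal sketch rather than anything rpart documents: instead of calling the internal `rpart:::pred.rpart`, it swaps in a different trick, overwriting `yval` in a copy of the fit with the frame row index so that the public `predict()` returns that index. It assumes a regression (anova) tree as in the question; `set.seed(1)` and the names `fit_idx`, `where_test` and `node_summary` are illustrative only.

```r
library(rpart)

set.seed(1)  # assumed seed, only so the split is repeatable
inTrain <- sample(seq_len(nrow(kyphosis)), 45)
TRAIN_KYPHOSIS <- kyphosis[inTrain, ]
TEST_KYPHOSIS  <- kyphosis[-inTrain, ]
kyph_tree <- rpart(Number ~ ., data = TRAIN_KYPHOSIS)

# Copy the fit and replace each node's fitted value by its row index in
# $frame, so predict() reports the terminal node's frame row instead.
fit_idx <- kyph_tree
fit_idx$frame$yval <- seq_len(nrow(kyph_tree$frame))
where_test <- as.integer(predict(fit_idx, newdata = TEST_KYPHOSIS))

# One row per test observation: node label, node size, deviance, prediction
node_summary <- data.frame(
  node = as.integer(row.names(kyph_tree$frame))[where_test],
  kyph_tree$frame[where_test, c("n", "dev", "yval")],
  row.names = row.names(TEST_KYPHOSIS)
)
head(node_summary)
```

The `yval` column should agree with `predict(kyph_tree, newdata = TEST_KYPHOSIS)`, which is a quick way to check the lookup.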
VitoshKa
  • Why does `rpart:::pred.rpart(object, rpart:::rpart.matrix(newdata))` lead to a predicted node? – goldisfine Sep 10 '15 at 19:27
  • @goldisfine, because this is how rpart computes the predicted nodes internally. This functionality is used internally, but not exposed. – VitoshKa Sep 18 '15 at 09:22
  • @VitoshKa Thanks for posting the solution. This is such a fundamental part of a tree! It is almost unusable without this part. – user1700890 Aug 12 '17 at 21:58

One option is to convert the rpart object to an object of class party from the partykit package. That provides a general toolkit for dealing with recursive partytions. The conversion is simple:

library("partykit")
(kyph_party <- as.party(kyph_tree))

Model formula:
Number ~ Kyphosis + Age + Start

Fitted party:
[1] root
|   [2] Start >= 15.5: 2.933 (n = 15, err = 10.9)
|   [3] Start < 15.5
|   |   [4] Age >= 112.5: 3.714 (n = 14, err = 18.9)
|   |   [5] Age < 112.5: 5.125 (n = 16, err = 29.8)

Number of inner nodes:    2
Number of terminal nodes: 3

(For exact reproducibility run the code from your question with set.seed(1) prior to running my code.)

For objects of this class there are somewhat more flexible methods for plot(), predict(), fitted(), etc. For example, plot(kyph_party) yields a more informative display than the default plot(kyph_tree). The fitted() method extracts a two-column data.frame with the fitted node numbers and the observed responses on the training data.

kyph_fit <- fitted(kyph_party)
head(kyph_fit, 3)

  (fitted) (response)
1        5          6
2        2          2
3        4          3

With this you can easily compute any quantity you are interested in, e.g., the mean, median, or residual sum of squares within each node.

tapply(kyph_fit[,2], kyph_fit[,1], mean)

       2        4        5 
2.933333 3.714286 5.125000 

tapply(kyph_fit[,2], kyph_fit[,1], median)

2 4 5 
3 4 5 

tapply(kyph_fit[,2], kyph_fit[,1], function(x) sum((x - mean(x))^2))

       2        4        5 
10.93333 18.85714 29.75000 

Instead of the simple tapply() you can use any other function of your choice to compute the tables of grouped statistics.
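
As one possible alternative to repeated `tapply()` calls, the sketch below collects several per-node statistics into a single table using base R's `split()` and `lapply()`. It is self-contained, so it refits the tree; `set.seed(1)` and the name `node_stats` are assumptions for illustration, and the exact numbers depend on the random split.

```r
library(rpart)
library(partykit)

set.seed(1)  # assumed seed, only so the split is repeatable
inTrain <- sample(seq_len(nrow(kyphosis)), 45)
TRAIN_KYPHOSIS <- kyphosis[inTrain, ]
kyph_tree  <- rpart(Number ~ ., data = TRAIN_KYPHOSIS)
kyph_party <- as.party(kyph_tree)

# Column 1 holds the fitted node IDs, column 2 the observed responses
kyph_fit <- fitted(kyph_party)

# Group the responses by node and compute several statistics at once
node_stats <- do.call(rbind, lapply(
  split(kyph_fit[, 2], kyph_fit[, 1]),
  function(x) c(n = length(x), mean = mean(x), median = median(x),
                rss = sum((x - mean(x))^2))
))
node_stats  # one row per terminal node, row names are the node IDs
```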

Now, to learn which observation from the test data TEST_KYPHOSIS is assigned to which node in the tree, you can simply use the predict(..., type = "node") method:

kyph_pred <- predict(kyph_party, newdata = TEST_KYPHOSIS, type = "node")
head(kyph_pred)

 2  3  4  6  7 10 
 4  4  5  2  2  5 
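
To attach a summary to each test observation, the predicted node IDs can be used as keys into per-node statistics computed on the training data. This is a minimal sketch under the same setup as the question; `set.seed(1)` and the names `node_mean`, `node_rss` and `test_summary` are illustrative assumptions.

```r
library(rpart)
library(partykit)

set.seed(1)  # assumed seed, only so the split is repeatable
inTrain <- sample(seq_len(nrow(kyphosis)), 45)
TRAIN_KYPHOSIS <- kyphosis[inTrain, ]
TEST_KYPHOSIS  <- kyphosis[-inTrain, ]
kyph_tree  <- rpart(Number ~ ., data = TRAIN_KYPHOSIS)
kyph_party <- as.party(kyph_tree)

# Per-node mean and residual sum of squares on the training data
kyph_fit  <- fitted(kyph_party)
node_mean <- tapply(kyph_fit[, 2], kyph_fit[, 1], mean)
node_rss  <- tapply(kyph_fit[, 2], kyph_fit[, 1],
                    function(x) sum((x - mean(x))^2))

# Terminal node for each test observation, then look up its statistics
kyph_pred <- predict(kyph_party, newdata = TEST_KYPHOSIS, type = "node")
key <- as.character(kyph_pred)
test_summary <- data.frame(
  node      = as.integer(kyph_pred),
  predicted = as.numeric(node_mean[key]),
  deviance  = as.numeric(node_rss[key]),
  row.names = row.names(TEST_KYPHOSIS)
)
head(test_summary)
```

For a regression tree the `predicted` column is the training mean of the node, so it should match `predict(kyph_tree, newdata = TEST_KYPHOSIS)`.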
Achim Zeileis
  • Your solution yields the same result as `kyph_tree$where`, and it is different from the results obtained with VitoshKa's solution. – user1700890 Aug 12 '17 at 21:45
  • The `predict_nodes()` solution by VitoshKa and the `predict(..., type = "node")` solution in `partykit` do not yield exactly the same _labels_ because the node IDs are assigned a bit differently. But the information is in fact equivalent. Check out: `table(predict_nodes(kyph_tree, TEST_KYPHOSIS), predict(kyph_party, newdata = TEST_KYPHOSIS, type = "node"))`. It may not be diagonal due to the different labels, but there is a 1:1 match. This is because `partykit` provides general solutions for recursive partitioning which are not specific to `rpart`. – Achim Zeileis Aug 13 '17 at 20:53
  • Thank you for your reply. I am struggling to understand what `kyph_tree$where` returns. It does not look like it returns the terminal node label. – user1700890 Aug 13 '17 at 21:00
  • It's, in fact, simply the terminal node label/ID as on the training data `TRAIN_KYPHOSIS`. For example, check out `table(kyph_tree$where)`. If you compare that with `table(fitted(kyph_party)[,1])` or `table(predict(kyph_party, type = "node"))` you will see the same absolute frequencies but possibly different labels (depending on the structure of the tree). – Achim Zeileis Aug 13 '17 at 21:30
  • Thank you again. I am not sure I understand what you mean by label/ID. `unique(kyph_tree$where)` returns `3,7,5,6`, but if you look at the resulting tree the terminal nodes are `4, 10, 11, 3`. I used `fancyRpartPlot(kyph_tree)` to plot the tree. – user1700890 Aug 13 '17 at 21:39
  • Good point. The IDs in `$where` pertain to rows in `$frame`, i.e., they are not the same labels used in printing/plotting but correspond to IDs in a summary table. If you use the row names from that summary table you get the same labels used in printing/plotting: `table(as.numeric(rownames(kyph_tree$frame))[kyph_tree$where])`. – Achim Zeileis Aug 14 '17 at 08:10
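
The three labelings discussed in these comments can be put side by side on the training data. The sketch below is only an illustration of that comparison; `set.seed(1)` and the name `leaf_map` are assumptions, and the actual IDs depend on the tree that is fitted.

```r
library(rpart)
library(partykit)

set.seed(1)  # assumed seed, only so the split is repeatable
inTrain <- sample(seq_len(nrow(kyphosis)), 45)
TRAIN_KYPHOSIS <- kyphosis[inTrain, ]
kyph_tree <- rpart(Number ~ ., data = TRAIN_KYPHOSIS)

# For every training observation: the row of $frame it ends up in, the
# rpart label of that row, and the node ID assigned by partykit.
leaf_map <- unique(data.frame(
  frame_row   = as.integer(kyph_tree$where),
  rpart_label = as.integer(row.names(kyph_tree$frame))[kyph_tree$where],
  party_id    = as.integer(predict(as.party(kyph_tree), type = "node")),
  row.names   = NULL
))
leaf_map[order(leaf_map$frame_row), ]  # one row per terminal node
```

If the labelings really are equivalent, `leaf_map` has exactly one row per terminal node and no value repeats within a column.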