
I can't find the split values (or other node data) in an rpart object. summary(sample_model) shows them, but I can't find them anywhere in the returned list or in its frame data frame.

Some sample data:

foo.df <- structure(list(type = c("fudai", "fudai", "fudai", "fudai", "fudai", 
                              "fudai", "fudai", "tozama", "fudai", "fudai", "tozama", "tozama", 
                              "fudai", "tozama", "fudai", "fudai", "tozama", "fudai", "fudai", 
                              "tozama", "fudai", "fudai", "fudai", "tozama", "fudai", "fudai", 
                              "tozama", "fudai", "fudai", "fudai", "fudai", "fudai", "tozama", 
                              "fudai", "fudai", "fudai", "fudai", "fudai", "fudai", "tozama", 
                              "tozama", "fudai", "tozama", "tozama", "tozama", "tozama", "fudai", 
                              "fudai", "tozama", "tozama"), distance = c(12.5366985071383, 
                                                                         272.697138147139, 40.4780423740381, 109.806349869662, 147.781805212839, 
                                                                         89.4280438527415, 49.1425850803745, 555.414271440522, 119.365138867582, 
                                                                         182.902536555383, 310.019126513348, 277.122207392514, 214.510428881317, 
                                                                         235.111617874157, 104.494518693549, 50.7561853895564, 343.308898045237, 
                                                                         151.796857505073, 36.0391449169937, 30.8214406651022, 343.294467363406, 
                                                                         135.841501028422, 154.798119311647, 317.739208576563, 3.33794280697559, 
                                                                         98.9182898110913, 422.915369767251, 194.957988642709, 87.6548263591412, 
                                                                         187.571370158631, 236.292608259126, 17.915709270268, 193.548578374405, 
                                                                         262.190146422316, 21.6219797945323, 121.199009527283, 261.670997612517, 
                                                                         202.2051991431, 125.418459536787, 275.964068539003, 190.112226847932, 
                                                                         20.1753302760961, 488.80323504215, 579.25515722891, 233.500797034697, 
                                                                         207.588349435329, 183.770003408524, 168.739293254246, 313.140075747773, 
                                                                         131.69228390613), age = c(1756, 1711, 1712, 1746, 1868, 1866, 
                                                                                                   1682, 1617, 1771, 1764, 1672, 1636, 1864, 1704, 1762, 1868, 1694, 
                                                                                                   1749, 1703, 1616, 1691, 1702, 1723, 1683, 1742, 1691, 1623, 1721, 
                                                                                                   1704, 1745, 1749, 1723, 1639, 1661, 1843, 1845, 1669, 1698, 1698, 
                                                                                                   1664, 1868, 1633, 1783, 1642, 1615, 1648, 1734, 1758, 1725, 1635
                                                                         )), class = c("tbl_df", "tbl", "data.frame"), row.names = c(NA, 
                                                                                                                                     -50L))

And a basic model:

library("rpart")
sample_model <- rpart(formula = type ~ .,
                      data = foo.df,
                      method = "class",
                      control = rpart.control(xval = 50, minbucket = 5, cp = 0.05),
                      parms = list(split = "gini"))

The rpart documentation says that sample_model$frame should contain a column called "splits", but it isn't there. To quote: "splits, a two column matrix of left and right split labels for each node" https://www.rdocumentation.org/packages/rpart/versions/4.1-15/topics/rpart.object

Where is that data in sample_model$frame, or anywhere else in sample_model? I can see exactly what I want in

summary(sample_model)

What's going on?

Mark R

2 Answers


The docs are indeed outdated. Here is an extractor derived by inspecting the summary.rpart function:


rpart_splits <- function(fit, digits = getOption("digits")) {
  splits <- fit$splits
  if (!is.null(splits)) {
    ff <- fit$frame
    is.leaf <- ff$var == "<leaf>"
    n <- nrow(splits)
    # rows of `splits` owned by each node: main + competitors + surrogates
    nn <- ff$ncompete + ff$nsurrogate + !is.leaf
    # row index of each node's first (main) split in `splits`
    ix <- cumsum(c(1L, nn))
    # rows holding the main split and its competitors for each node
    ix_prim <- unlist(mapply(ix, ix + c(ff$ncompete, 0), FUN = seq, SIMPLIFY = FALSE))
    type <- rep.int("surrogate", n)
    type[ix_prim[ix_prim <= n]] <- "primary"
    type[ix[ix <= n]] <- "main"
    # readable rule for the observations sent to the left child
    left <- character(nrow(splits))
    side <- splits[, 2L]
    for (i in seq_along(left)) {
      left[i] <- if (side[i] == -1L)
                   paste("<", format(signif(splits[i, 4L], digits)))
                 else if (side[i] == 1L)
                   paste(">=", format(signif(splits[i, 4L], digits)))
                 else {
                   # categorical split: L/-/R pattern per level, from csplit
                   catside <- fit$csplit[splits[i, 4L], 1:side[i]]
                   paste(c("L", "-", "R")[catside], collapse = "", sep = "")
                 }
    }
    cbind(data.frame(var = rownames(splits),
                     type = type,
                     node = rep(as.integer(row.names(ff)), times = nn),
                     ix = rep(seq_len(nrow(ff)), nn),
                     left = left),
          as.data.frame(splits, row.names = FALSE))
  }
}

Filter on type == "main" to keep only the main splits. For example, on the kyphosis data that ships with rpart:

> fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)
> rpart_splits(fit)
      var      type node ix    left count ncat    improve index       adj
1   Start      main    1  1  >= 8.5    81    1 6.76232996   8.5 0.0000000
2  Number   primary    1  1   < 5.5    81   -1 2.86679493   5.5 0.0000000
3     Age   primary    1  1  < 39.5    81   -1 2.25021152  39.5 0.0000000
4  Number surrogate    1  1   < 6.5     0   -1 0.80246914   6.5 0.1578947
5   Start      main    2  2 >= 14.5    62    1 1.02052786  14.5 0.0000000
6     Age   primary    2  2    < 55    62   -1 0.68486352  55.0 0.0000000
7  Number   primary    2  2   < 4.5    62   -1 0.29753321   4.5 0.0000000
8  Number surrogate    2  2   < 3.5     0   -1 0.64516129   3.5 0.2413793
9     Age surrogate    2  2    < 16     0   -1 0.59677419  16.0 0.1379310
10    Age      main    5  4    < 55    33   -1 1.24675325  55.0 0.0000000
11  Start   primary    5  4 >= 12.5    33    1 0.28877005  12.5 0.0000000
12 Number   primary    5  4  >= 3.5    33    1 0.17532468   3.5 0.0000000
13  Start surrogate    5  4   < 9.5     0   -1 0.75757576   9.5 0.3333333
14 Number surrogate    5  4  >= 5.5     0    1 0.69696970   5.5 0.1666667
15    Age      main   11  6  >= 111    21    1 1.71428571 111.0 0.0000000
16  Start   primary   11  6 >= 12.5    21    1 0.79365079  12.5 0.0000000
17 Number   primary   11  6  >= 3.5    21    1 0.07142857   3.5 0.0000000
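If only the main splits are needed, the same row bookkeeping that rpart_splits() uses can be applied directly. This is a minimal sketch, not part of rpart's API: it assumes, as summary.rpart does, that fit$splits stores each internal node's rows consecutively in frame-row order (main split first, then its ncompete competitors, then its nsurrogate surrogates). The name main_splits is hypothetical.

```r
library(rpart)

fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)
ff <- fit$frame
is_leaf <- ff$var == "<leaf>"

# Rows of fit$splits owned by each node (leaves own none).
n_rows <- ff$ncompete + ff$nsurrogate + !is_leaf
# Row in fit$splits where each node's block starts, i.e. its main split.
first_row <- cumsum(c(1L, n_rows))[seq_len(nrow(ff))]

# Keep only the first row of each internal node.
main_idx <- first_row[!is_leaf]
main_splits <- data.frame(
  node  = as.integer(row.names(ff))[!is_leaf],
  var   = rownames(fit$splits)[main_idx],
  ncat  = fit$splits[main_idx, "ncat"],
  index = fit$splits[main_idx, "index"]
)
main_splits
```

For a numeric variable, ncat is -1 when the "<" side goes left and 1 when ">=" goes left (the same test rpart_splits() uses to build its left column); for a categorical variable it is the number of levels, and index points to a row of fit$csplit.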

VitoshKa
  • Assuming regressing a given outcome `y` on two covariates `x1` and `x2`, is there a way to use this information to plot the axis-aligned splits in a 2d plot? – riccardo-df Jun 10 '22 at 14:56

I see that now, but the documentation doesn't seem to describe the current structure. The $splits item is a separate list element, not a column of $frame:

  sample_model$splits

         count ncat  improve     index adj
distance    50   -1 9.134639  274.3306   0
age         50    1 7.910588 1687.0000   0
age         39    1 6.062937 1654.5000   0
distance    39   -1 1.950142  188.8418   0
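A small sketch to confirm this on any fitted tree and to see why $splits can have more rows than there are nodes. It uses the kyphosis data bundled with rpart instead of the question's data, and the variable names are illustrative. Each internal node contributes one main split plus ncompete competitor rows and nsurrogate surrogate rows, so the per-node counts stored in frame should add up to nrow(fit$splits); summary.rpart's indexing relies on the same relationship.

```r
library(rpart)

fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)

# splits is a top-level component of the fit, not a column of fit$frame
location <- c(top_level = "splits" %in% names(fit),
              in_frame  = "splits" %in% names(fit$frame))
location
colnames(fit$splits)

# Rows each node owns in fit$splits: 1 main (internal nodes only)
# + ncompete competitors + nsurrogate surrogates
ff <- fit$frame
rows_per_node <- ff$ncompete + ff$nsurrogate + (ff$var != "<leaf>")
split_rows <- data.frame(node = row.names(ff),
                         var  = as.character(ff$var),
                         rows = rows_per_node)
split_rows
sum(rows_per_node) == nrow(fit$splits)
```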

To see the full structure of the sample_model, do this:

str(sample_model)

I was unable to confirm my hunch about the docs lagging the code:

news(grepl('splits', Text), 'rpart')

Changes in version 4.1-0

Surrogate splits are now considered only if they send two or more cases with non-zero weight each way. For numeric/ordinal variables the restriction to non-zero weights is new: for categorical variables this is a new restriction. Surrogate splits which improve only by rounding error over the default split are no longer returned. Where weights and missing values are present, the splits component for some of these was not returned correctly.

Changes in version 4.0-1

The other major change was an error for asymmetric loss matrices, prompted by a user query. With L=loss asymmetric, the altered priors were computed incorrectly - they were using L' instead of L. Upshot - the tree would not not necessarily choose optimal splits for the given loss matrix. Once chosen, splits were evaluated correctly. The printed “improvement” values are of course the wrong ones as well. It is interesting that for my little test case, with L quite asymmetric, the early splits in the tree are unchanged - a good split still looks good.

To get a canonical answer you would need to contact the maintainer:

 maintainer('rpart')
IRTFM
  • I thought that might be the case, but in the help page, splits is documented both under frame and then again under splits. Also, sample_model$splits has 4 rows for only 2 nodes. What marks the actual node splits (rather than competitor or surrogate splits), other than the higher improve value? – Mark R May 19 '19 at 18:02
  • I'm guessing that the documentation may not have been updated when splits were broken out of frame, but that's just a guess. – IRTFM May 19 '19 at 18:06
  • So to get the actual split values I could use sample_model$splits[c(1,3),4] but that seems odd. I could set maxcompete to zero in rpart. I guess my core question is "how do I capture the simple format of print(sample_model)" – Mark R May 19 '19 at 19:10
  • You could examine the code returned by `getAnywhere(print.rpart)` and see which portions you find useful. – IRTFM May 19 '19 at 19:35
  • Thanks! I was expecting to just extract some obviously named parameters, as with lm(), for example. But now I'm starting to see the underlying values used to calculate the summary. – Mark R May 19 '19 at 19:55
  • Any idea how to get the simple "class counts" produced by summary(sample_model)? – Mark R May 20 '19 at 13:32
  • I'm sure the answer is in the code for `summary.rpart`. – IRTFM May 20 '19 at 16:31
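For the follow-up questions about class counts and the compact print() layout, here is a minimal sketch for a classification tree (again on the bundled kyphosis data). For method = "class", frame$yval2 holds, per node, the fitted class, the class counts, the class probabilities and the node probability, and attr(fit, "ylevels") gives the class labels. labels(fit) is rpart's labels method, which print.rpart uses for the split text; treat its exact formatting as an assumption. The name node_table is hypothetical.

```r
library(rpart)

fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)

lev <- attr(fit, "ylevels")   # class labels, in factor-level order
yv  <- fit$frame$yval2        # matrix: fitted class, counts, probabilities, node prob

# Columns 2..(1 + number of classes) are the per-node class counts.
counts <- yv[, 1 + seq_along(lev), drop = FALSE]
dimnames(counts) <- list(row.names(fit$frame), lev)

# One row per node, similar to print(fit): split label, n, class counts.
node_table <- data.frame(label = labels(fit),
                         n = fit$frame$n,
                         counts,
                         check.names = FALSE,
                         row.names = row.names(fit$frame))
node_table
```

Indexing the yval2 columns by position avoids depending on whatever column names the matrix carries.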