
I'd like to print a decision tree nicely as text. For example, I can print the tree object itself:

library(rpart)

f = as.formula('Sepal.Length ~ Sepal.Width + Petal.Length + Petal.Width + Species')
fit = rpart(f, data = iris, control = rpart.control(xval = 3))

fit

yields

n= 150 

node), split, n, deviance, yval
      * denotes terminal node

 1) root 150 102.1683000 5.843333  
   2) Petal.Length< 4.25 73  13.1391800 5.179452  
     4) Petal.Length< 3.4 53   6.1083020 5.005660  
       8) Sepal.Width< 3.25 20   1.0855000 4.735000 *
       9) Sepal.Width>=3.25 33   2.6696970 5.169697 *
... # omitted

partykit prints it more neatly:

library(partykit)

as.party(fit)

yields

Model formula:
Sepal.Length ~ Sepal.Width + Petal.Length + Petal.Width + Species

Fitted party:
[1] root
|   [2] Petal.Length < 4.25
|   |   [3] Petal.Length < 3.4
|   |   |   [4] Sepal.Width < 3.25: 4.735 (n = 20, err = 1.1)
|   |   |   [5] Sepal.Width >= 3.25: 5.170 (n = 33, err = 2.7)
|   |   [6] Petal.Length >= 3.4: 5.640 (n = 20, err = 1.2)
...# omitted

Number of inner nodes:    6
Number of terminal nodes: 7

Is there a way to have more control over the output? For example, I don't want to print n and err, or I want the standard deviation printed instead of err.

Achim Zeileis
YJZ

2 Answers


Not a very elegant answer, but if you just want to get rid of n= and err= you can capture the output and edit it.

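# Capture the printed tree as a character vector, one element per line
CO = capture.output(print(as.party(fit)))
# Remove the parenthesised "(n = ..., err = ...)" part of each line
CO2 = sub("\\(.*\\)", "", CO)
cat(paste(CO2, collapse="\n"))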

Model formula:
Sepal.Length ~ Sepal.Width + Petal.Length + Petal.Width + Species

Fitted party:
[1] root
|   [2] Petal.Length < 4.25
|   |   [3] Petal.Length < 3.4
|   |   |   [4] Sepal.Width < 3.25: 4.735 
|   |   |   [5] Sepal.Width >= 3.25: 5.170 
|   |   [6] Petal.Length >= 3.4: 5.640 
|   [7] Petal.Length >= 4.25

I am not sure what standard deviation you want to insert, but I expect you could edit that in the same way.
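The same capture-and-edit idea can be made a little more selective. Here is a minimal sketch on a few lines copied from the printed tree in the question (hard-coded so it runs without partykit); the object names keep_n and no_stats are just illustrative. It shows one pattern that drops only the err entry and one that drops the whole parenthesised summary, including the space in front of it.

```r
# Sample lines copied from the printed tree in the question
lines <- c(
  "|   |   [3] Petal.Length < 3.4",
  "|   |   |   [4] Sepal.Width < 3.25: 4.735 (n = 20, err = 1.1)",
  "|   |   [6] Petal.Length >= 3.4: 5.640 (n = 20, err = 1.2)"
)

# Drop only the err entry and keep n; "[^)]*" stops at the closing parenthesis
keep_n <- sub(", err = [^)]*\\)", ")", lines)

# Drop the whole parenthesised summary together with the space before it
no_stats <- sub(" *\\(.*\\)$", "", lines)

cat(keep_n, sep = "\n")
cat(no_stats, sep = "\n")
```

Lines without a parenthesised summary (inner nodes) pass through unchanged, since sub() returns its input when the pattern does not match.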

G5W

The print() method for party objects is quite flexible and can be controlled through various panel functions and customizations. See ?print.party for an overview. The documentation is somewhat short and technical, though.

In your case, the easiest solution is to set up a function of the response y, the case weights w (defaulting to all 1 in your case), and the desired number of digits:

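myfun <- function(y, w, digits = 2) {
  n <- sum(w)                # (weighted) number of observations in the node
  m <- weighted.mean(y, w)   # node mean
  # weighted variance rescaled by n/(n - 1), then square root
  s <- sqrt(weighted.mean((y - m)^2, w) * n/(n - 1))
  sprintf("%s (serr = %s)",
    round(m, digits = digits),
    round(s, digits = digits))
}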
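
As a quick sanity check on the formula for s: with all weights equal to 1, the n/(n - 1) rescaling should turn the weighted variance into the usual sample variance, so s should agree with sd(). The sketch below checks this on one species of the iris data; node_sd is a hypothetical helper name used only for this illustration, not something provided by partykit.

```r
# Hypothetical helper: same formula as s in myfun, but returning the number
node_sd <- function(y, w = rep(1, length(y))) {
  n <- sum(w)
  m <- weighted.mean(y, w)
  sqrt(weighted.mean((y - m)^2, w) * n / (n - 1))
}

# With unit weights the result should match the sample standard deviation
y <- iris$Sepal.Length[iris$Species == "setosa"]
agrees <- isTRUE(all.equal(node_sd(y), sd(y)))
print(agrees)
```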

Then pass myfun to your print() call through the FUN argument:

p <- as.party(fit)
print(p, FUN = myfun)
## Model formula:
## Sepal.Length ~ Sepal.Width + Petal.Length + Petal.Width + Species
## 
## Fitted party:
## [1] root
## |   [2] Petal.Length < 4.25
## |   |   [3] Petal.Length < 3.4
## |   |   |   [4] Sepal.Width < 3.25: 4.735 (serr = 0.239)
## |   |   |   [5] Sepal.Width >= 3.25: 5.17 (serr = 0.289) 
## |   |   [6] Petal.Length >= 3.4: 5.64 (serr = 0.25)  
## |   [7] Petal.Length >= 4.25
## |   |   [8] Petal.Length < 6.05
## |   |   |   [9] Petal.Length < 5.15
## |   |   |   |   [10] Sepal.Width < 3.05: 6.055 (serr = 0.404)
## |   |   |   |   [11] Sepal.Width >= 3.05: 6.53 (serr = 0.38)  
## |   |   |   [12] Petal.Length >= 5.15: 6.604 (serr = 0.302)
## |   |   [13] Petal.Length >= 6.05: 7.578 (serr = 0.228)
## 
## Number of inner nodes:    6
## Number of terminal nodes: 7
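
If you would rather drop the parenthesised part altogether, the same mechanism should work with a function that returns only the mean. A minimal sketch, assuming FUN is called with the same (y, w, digits) arguments as myfun above; mean_only is a made-up name, and the print() call is left commented out because it needs partykit and the fitted tree.

```r
# Hypothetical formatter: node mean only, without n, err, or any spread measure
mean_only <- function(y, w, digits = 2) {
  format(round(weighted.mean(y, w), digits = digits))
}

# Stand-alone call on toy data (no partykit needed)
mean_only(c(5.1, 4.9, 5.0), w = c(1, 1, 1))

# With the tree from above (requires partykit and p <- as.party(fit)):
# print(p, FUN = mean_only)
```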
Achim Zeileis