0

I am trying to plot a regression tree, fitted with rpart, using partykit. The tree is generated with the following code:

library(rpart)
# Regression tree of Price on the cu.summary data used with rpart.
fit <- rpart(
  formula = Price ~ Mileage + Type + Country,
  data = cu.summary
)

library(partykit)
# Convert the rpart fit into a partykit "party" object for partykit plotting.
tree.2 <- as.party(fit)

# Simple tree layout; each terminal node is drawn by partykit's boxplot
# panel (linear y axis).
plot(
  tree.2,
  type = "simple",
  terminal_panel = node_boxplot(
    tree.2,
    col = "black",
    fill = "lightgray",
    width = 0.5,
    yscale = NULL,
    ylines = 3,
    cex = 0.5,
    id = TRUE
  )
)

I am trying to modify the boxplots in the terminal nodes so that the y axis is on a log scale.

I am aware that, for an ordinary boxplot, all we have to do is specify boxplot(data, log="y"), which is why I tried to modify the function node_boxplot only in the single line where boxplot is called. However, I keep getting the same graph. Is there something I am missing? Any feedback would be greatly appreciated.

# Terminal-panel generator for partykit trees that draws each terminal node's
# response as a boxplot on a log-scaled y axis.
#
# The earlier version log-transformed only the overall response (used for
# `yscale`) and drew the per-node boxes from the untransformed node response.
# The boxes were therefore placed in raw units inside a log-unit viewport and
# were clipped away. Here both are log-transformed, and the y axis is labelled
# with values on the original scale.
#
# Arguments mirror partykit::node_boxplot():
#   obj      party object; `obj$fitted[["(response)"]]` must be numeric and
#            strictly positive (required for the log transform).
#   yscale   y range of the plotting viewport in *log* units. Default: the
#            log-range of the response padded by 10% on each side.
#   col, fill, bg, width, ylines, cex
#            box colour, box fill, panel background, box width (npc),
#            left margin in lines, and outlier point size.
#   id       if TRUE, the default title includes the node id.
#   mainlab  NULL, a string, or function(id, nobs) giving the panel title.
#   gp       gpar() applied to the whole panel.
# Returns a panel function `function(node)` that draws into the current grid
# viewport. The generator has class "grapcon_generator", so it can also be
# passed uncalled: plot(tree, terminal_panel = node_boxplot2).
node_boxplot2 <- function(obj, col = "black", fill = "lightgray", bg = "white",
                          width = 0.5, yscale = NULL, ylines = 3, cex = 0.5,
                          id = TRUE, mainlab = NULL, gp = gpar()) {
  y_raw <- obj$fitted[["(response)"]]
  stopifnot(is.numeric(y_raw))
  # log() of zero or negative values yields -Inf/NaN and breaks the y scale.
  if (any(y_raw <= 0, na.rm = TRUE)) {
    stop("node_boxplot2() requires a strictly positive response", call. = FALSE)
  }
  y <- log(y_raw)

  if (is.null(yscale)) {
    y_range <- range(y, na.rm = TRUE)
    # Pad both ends so the whisker ends are not drawn on the panel border.
    yscale <- y_range + c(-0.1, 0.1) * diff(y_range)
    # grid rejects a zero-width scale (e.g. a constant response).
    if (diff(yscale) == 0) yscale <- yscale + c(-0.5, 0.5)
  }

  # Tick values chosen on the original scale. Only those whose log lies inside
  # the viewport are kept, because the viewport does not clip the axis.
  yaxis <- pretty(exp(yscale))
  yaxis <- yaxis[yaxis > 0]
  yaxis <- yaxis[log(yaxis) >= yscale[1] & log(yaxis) <= yscale[2]]

  rval <- function(node) {
    ## data for this node, on the same log scale as `yscale`
    nid <- id_node(node)
    dat <- data_party(obj, nid)
    yn <- log(dat[["(response)"]])
    wn <- dat[["(weights)"]]
    if (is.null(wn)) wn <- rep(1, length(yn))

    ## boxplot statistics only; the drawing is done with grid below
    x <- boxplot(rep.int(yn, wn), plot = FALSE)

    top_vp <- viewport(
      layout = grid.layout(
        nrow = 2, ncol = 3,
        widths = unit(c(ylines, 1, 1), c("lines", "null", "lines")),
        heights = unit(c(1, 1), c("lines", "null"))
      ),
      width = unit(1, "npc"),
      height = unit(1, "npc") - unit(2, "lines"),
      name = paste0("node_boxplot", nid),
      gp = gp
    )
    pushViewport(top_vp)
    grid.rect(gp = gpar(fill = bg, col = 0))

    ## main title
    top <- viewport(layout.pos.col = 2, layout.pos.row = 1)
    pushViewport(top)
    if (is.null(mainlab)) {
      mainlab <- if (id) {
        function(id, nobs) sprintf("Node %s (n = %s)", id, nobs)
      } else {
        function(id, nobs) sprintf("n = %s", nobs)
      }
    }
    if (is.function(mainlab)) {
      mainlab <- mainlab(names(obj)[nid], sum(wn))
    }
    grid.text(mainlab)
    popViewport()

    ## plotting region, y in log units
    plot_vp <- viewport(
      layout.pos.col = 2, layout.pos.row = 2,
      xscale = c(0, 1), yscale = yscale,
      name = paste0("node_boxplot", nid, "plot"),
      clip = FALSE
    )
    pushViewport(plot_vp)
    # ticks at log positions, labelled with original-scale values
    if (length(yaxis) > 0) grid.yaxis(at = log(yaxis), label = yaxis)
    grid.rect(gp = gpar(fill = "transparent"))
    grid.clip()

    xl <- 0.5 - width / 4
    xr <- 0.5 + width / 4

    ## box & whiskers (x$stats: lower whisker, lower hinge, median,
    ## upper hinge, upper whisker)
    grid.lines(unit(c(xl, xr), "npc"), unit(x$stats[1], "native"),
               gp = gpar(col = col))
    grid.lines(unit(0.5, "npc"), unit(x$stats[1:2], "native"),
               gp = gpar(col = col, lty = 2))
    grid.rect(unit(0.5, "npc"), unit(x$stats[2], "native"),
              width = unit(width, "npc"),
              height = unit(diff(x$stats[c(2, 4)]), "native"),
              just = c("center", "bottom"),
              gp = gpar(col = col, fill = fill))
    grid.lines(unit(c(0.5 - width / 2, 0.5 + width / 2), "npc"),
               unit(x$stats[3], "native"), gp = gpar(col = col, lwd = 2))
    grid.lines(unit(0.5, "npc"), unit(x$stats[4:5], "native"),
               gp = gpar(col = col, lty = 2))
    grid.lines(unit(c(xl, xr), "npc"), unit(x$stats[5], "native"),
               gp = gpar(col = col))

    ## outliers
    n_out <- length(x$out)
    if (n_out > 0) {
      grid.points(unit(rep.int(0.5, n_out), "npc"),
                  unit(x$out, "native"),
                  size = unit(cex, "char"), gp = gpar(col = col))
    }

    upViewport(2)
  }
  rval
}
class(node_boxplot2) <- "grapcon_generator"
Alvaro GC
  • 46
  • 5

1 Answer

2

(1) If plotting is more appropriate on a log-scale, then I would usually expect that growing the tree is also better done on a log-scale. Here, you could simply use rpart(log(Price) ~ ...).

(2) If you only want to draw a different scale in the node boxplots, a little more work is needed, because the box plots are drawn "by hand" using the grid.*() functions. In the code below, I log-transform both the overall response and the response in the node being plotted, and then adjust the grid.yaxis() call accordingly. The function node_logboxplot() is simply a copy of node_boxplot() with a few small modifications (marked by #!!#). With this you can do

plot(tree.2, terminal_panel = node_logboxplot)

[Figure: tree plotted with node_logboxplot]

compared to

plot(tree.2, terminal_panel = node_boxplot)

[Figure: tree plotted with node_boxplot]

Modified panel function:

# Terminal-panel generator for partykit trees: a copy of
# partykit::node_boxplot() that draws the node boxplots on a log-scaled
# y axis. Changes relative to node_boxplot() are marked with #!!#.
#
# Arguments mirror partykit::node_boxplot():
#   obj      party object; `obj$fitted[["(response)"]]` must be numeric and
#            strictly positive (required for the log transform).
#   yscale   y range of the plotting viewport in *log* units. Default: the
#            log-range of the response padded by 10% on each side.
#   col, fill, bg, width, ylines, cex
#            box colour, box fill, panel background, box width (npc),
#            left margin in lines, and outlier point size.
#   id       if TRUE, the default title includes the node id.
#   mainlab  NULL, a string, or function(id, nobs) giving the panel title.
#   gp       gpar() applied to the whole panel.
# Returns a panel function `function(node)` that draws into the current grid
# viewport. Because of the "grapcon_generator" class, the generator can be
# passed uncalled: plot(tree, terminal_panel = node_logboxplot).
node_logboxplot <- function(obj,
                            col = "black",
                            fill = "lightgray",
                            bg = "white",
                            width = 0.5,
                            yscale = NULL,
                            ylines = 3,
                            cex = 0.5,
                            id = TRUE,
                            mainlab = NULL,
                            gp = gpar()) {
  y_raw <- obj$fitted[["(response)"]]
  stopifnot(is.numeric(y_raw))
  #!!# log() of zero or negative values yields -Inf/NaN and breaks the scale
  if (any(y_raw <= 0, na.rm = TRUE)) {
    stop("node_logboxplot() requires a strictly positive response",
         call. = FALSE)
  }
  y <- log(y_raw) #!!# log-transform overall response

  if (is.null(yscale)) {
    y_range <- range(y, na.rm = TRUE)
    yscale <- y_range + c(-0.1, 0.1) * diff(y_range)
    # grid rejects a zero-width scale (e.g. a constant response).
    if (diff(yscale) == 0) yscale <- yscale + c(-0.5, 0.5)
  }

  #!!# axis labels on the original scale, derived from yscale. Only ticks
  #!!# inside the viewport are kept, because the viewport does not clip.
  yaxis <- pretty(exp(yscale))
  yaxis <- yaxis[yaxis > 0]
  yaxis <- yaxis[log(yaxis) >= yscale[1] & log(yaxis) <= yscale[2]]

  ### panel function for boxplots in nodes
  rval <- function(node) {
    ## extract data
    nid <- id_node(node)
    dat <- data_party(obj, nid)
    yn <- log(dat[["(response)"]]) #!!# log-transform response in node
    wn <- dat[["(weights)"]]
    if (is.null(wn)) wn <- rep(1, length(yn))

    ## boxplot statistics only; the drawing is done with grid below
    x <- boxplot(rep.int(yn, wn), plot = FALSE)

    top_vp <- viewport(
      layout = grid.layout(
        nrow = 2, ncol = 3,
        widths = unit(c(ylines, 1, 1), c("lines", "null", "lines")),
        heights = unit(c(1, 1), c("lines", "null"))
      ),
      width = unit(1, "npc"),
      height = unit(1, "npc") - unit(2, "lines"),
      name = paste0("node_boxplot", nid),
      gp = gp
    )
    pushViewport(top_vp)
    grid.rect(gp = gpar(fill = bg, col = 0))

    ## main title
    top <- viewport(layout.pos.col = 2, layout.pos.row = 1)
    pushViewport(top)
    if (is.null(mainlab)) {
      mainlab <- if (id) {
        function(id, nobs) sprintf("Node %s (n = %s)", id, nobs)
      } else {
        function(id, nobs) sprintf("n = %s", nobs)
      }
    }
    if (is.function(mainlab)) {
      mainlab <- mainlab(names(obj)[nid], sum(wn))
    }
    grid.text(mainlab)
    popViewport()

    ## plotting region, y in log units
    plot_vp <- viewport(
      layout.pos.col = 2, layout.pos.row = 2,
      xscale = c(0, 1), yscale = yscale,
      name = paste0("node_boxplot", nid, "plot"),
      clip = FALSE
    )
    pushViewport(plot_vp)

    #!!# use pre-computed axis labels (skip the axis if none fit)
    if (length(yaxis) > 0) grid.yaxis(at = log(yaxis), label = yaxis)
    grid.rect(gp = gpar(fill = "transparent"))
    grid.clip()

    xl <- 0.5 - width / 4
    xr <- 0.5 + width / 4

    ## box & whiskers (x$stats: lower whisker, lower hinge, median,
    ## upper hinge, upper whisker)
    grid.lines(unit(c(xl, xr), "npc"),
               unit(x$stats[1], "native"), gp = gpar(col = col))
    grid.lines(unit(0.5, "npc"),
               unit(x$stats[1:2], "native"), gp = gpar(col = col, lty = 2))
    grid.rect(unit(0.5, "npc"), unit(x$stats[2], "native"),
              width = unit(width, "npc"),
              height = unit(diff(x$stats[c(2, 4)]), "native"),
              just = c("center", "bottom"),
              gp = gpar(col = col, fill = fill))
    grid.lines(unit(c(0.5 - width / 2, 0.5 + width / 2), "npc"),
               unit(x$stats[3], "native"), gp = gpar(col = col, lwd = 2))
    grid.lines(unit(0.5, "npc"), unit(x$stats[4:5], "native"),
               gp = gpar(col = col, lty = 2))
    grid.lines(unit(c(xl, xr), "npc"), unit(x$stats[5], "native"),
               gp = gpar(col = col))

    ## outliers
    n_out <- length(x$out)
    if (n_out > 0) {
      grid.points(unit(rep.int(0.5, n_out), "npc"),
                  unit(x$out, "native"),
                  size = unit(cex, "char"), gp = gpar(col = col))
    }

    upViewport(2)
  }

  rval
}
class(node_logboxplot) <- "grapcon_generator"
Achim Zeileis
  • 15,710
  • 1
  • 39
  • 49
  • Thank you for your answer! I definitely need to take a closer look at all the `grid.*()` functions. – Alvaro GC Mar 28 '18 at 19:38
  • 1
    They are quite similar to the corresponding base graphics functions, but the arguments are somewhat streamlined. Also, you can use different coordinate scales in different viewports, which helps with more complex displays. – Achim Zeileis Mar 28 '18 at 19:57