There are various ways to extract the full data pertaining to a certain node and compute any quantity you are interested in. For the class distribution of a classification tree, one way is to coerce the tree to the simpleparty class, which stores the distribution in the info slot of each node.
Using the example from the vignette you mentioned, you can first fit the full constparty tree:
library("partykit")
data("GlaucomaM", package = "TH.data")
gtree <- ctree(Class ~ ., data = GlaucomaM)
Then coerce it to simpleparty:
gtree <- as.simpleparty(gtree)
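To see what the coercion stores before extracting everything, you can look at the info of a single node. The sketch below refits the same tree so that it runs on its own and uses node 2 as an arbitrary example. It only relies on the info list containing the distribution element used in the next step; the other elements that str() shows may differ between partykit versions.

```r
library("partykit")
data("GlaucomaM", package = "TH.data")

# same tree as above, coerced so that every node carries its summary in info
gs <- as.simpleparty(ctree(Class ~ ., data = GlaucomaM))

# nodeapply() returns a list with one element per requested node id
info2 <- nodeapply(gs, ids = 2, FUN = function(node) node$info)[[1]]
str(info2)

# class counts of the training observations that fall into node 2
info2$distribution
```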
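Then you can extract the distribution of each node as a list, bind the list into a table with one row per node, and compute the row proportions:
# class counts for every node, inner and terminal
tab <- nodeapply(gtree, nodeids(gtree), function(node) node$info$distribution)
tab <- do.call(rbind, tab)
# proportions() needs R >= 4.0.0; prop.table(tab, 1) does the same in older R
proportions(tab, 1)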
## glaucoma normal
## 1 0.50000000 0.50000000
## 2 0.86206897 0.13793103
## 3 0.93670886 0.06329114
## 4 0.12500000 0.87500000
## 5 0.21100917 0.78899083
## 6 0.09230769 0.90769231
## 7 0.38636364 0.61363636
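If you prefer to stay with the constparty object, the same table can be sketched from the fitted node ids instead. This assumes that the fitted data frame of a constparty holds the terminal node id of each training observation in the "(fitted)" column and its observed response in "(response)"; an inner node then collects all observations whose terminal node lies below it, which nodeids(..., from = id, terminal = TRUE) lists. Case weights are ignored, which is fine here only because the tree was fitted without weights.

```r
library("partykit")
data("GlaucomaM", package = "TH.data")
ct <- ctree(Class ~ ., data = GlaucomaM)

# terminal node id and observed class of each training observation
fit <- ct$fitted

# for every node, pool the observations whose terminal node lies below it
# (a terminal node counts as its own descendant) and tabulate their classes
dist <- t(sapply(nodeids(ct), function(id) {
  below <- fit[["(fitted)"]] %in% nodeids(ct, from = id, terminal = TRUE)
  table(fit[["(response)"]][below])
}))
rownames(dist) <- nodeids(ct)

proportions(dist, 1)
```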
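You can also adapt the panel function used for printing. By default only the terminal nodes show the prediction, n, and error; re-using the formatting helper from print.simpleparty as the inner_panel adds the same information to the inner nodes: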
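# format inner nodes like terminal nodes; .make_formatinfo_simpleparty() is
# not exported (hence :::), so it may change between partykit versions
simpleprint <- function(node) formatinfo_node(node,
  FUN = partykit:::.make_formatinfo_simpleparty(gtree),
  default = "*", prefix = ": ")
print(gtree, inner_panel = simpleprint)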
## Model formula:
## Class ~ ag + at + as + an + ai + eag + eat + eas + ean + eai +
## abrg + abrt + abrs + abrn + abri + hic + mhcg + mhct + mhcs +
## mhcn + mhci + phcg + phct + phcs + phcn + phci + hvc + vbsg +
## vbst + vbss + vbsn + vbsi + vasg + vast + vass + vasn + vasi +
## vbrg + vbrt + vbrs + vbrn + vbri + varg + vart + vars + varn +
## vari + mdg + mdt + mds + mdn + mdi + tmg + tmt + tms + tmn +
## tmi + mr + rnf + mdic + emd + mv
##
## Fitted party:
## [1] root
## | [2] vari <= 0.059: glaucoma (n = 87, err = 13.8%)
## | | [3] vasg <= 0.066: glaucoma (n = 79, err = 6.3%)
## | | [4] vasg > 0.066: normal (n = 8, err = 12.5%)
## | [5] vari > 0.059: normal (n = 109, err = 21.1%)
## | | [6] tms <= -0.066: normal (n = 65, err = 9.2%)
## | | [7] tms > -0.066: normal (n = 44, err = 38.6%)
##
## Number of inner nodes: 3
## Number of terminal nodes: 4
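The panel can be any function that takes a node and returns a string. As a sketch, the hypothetical countprint below (my own name, not a partykit function) labels the inner nodes with their raw class counts from info$distribution instead of the prediction and error rate. Only inner_panel is passed, so the terminal nodes keep the default simpleparty format.

```r
library("partykit")
data("GlaucomaM", package = "TH.data")
gtree <- as.simpleparty(ctree(Class ~ ., data = GlaucomaM))

# hypothetical panel: ": <class> = <count>, ..." for each inner node
countprint <- function(node) {
  d <- node$info$distribution
  paste0(": ", paste(names(d), d, sep = " = ", collapse = ", "))
}

print(gtree, inner_panel = countprint)
```

With the tree above, node 2 should then read "glaucoma = 75, normal = 12", which matches n = 87 and the 13.8% error shown earlier.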