I created a decision tree with the party package in R. I'm trying to get the route/branch with the maximum value.
It can be the mean value that comes from the box plot
or it can be the probability value that comes from the binary tree
(source: rdatamining.com)
This can actually be done pretty easily. Note, though, that while your definition of "maximum value" is clear for a regression tree, it is not very clear for a classification tree, because in each node every class level can have its own maximum.
Either way, here's a pretty simple helper function that returns the predictions of the terminal nodes for either type of tree
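GetPredicts <- function(ct){
  # prediction stored in node i (a mean for regression, class probabilities for classification)
  f <- function(ct, i) nodes(ct, i)[[1]]$prediction
  # IDs of the terminal nodes that the training observations fall into
  Terminals <- unique(where(ct))
  Predictions <- sapply(Terminals, f, ct = ct)
  if(is.matrix(Predictions)){
    # classification tree: one column per terminal node
    colnames(Predictions) <- Terminals
    return(Predictions)
  } else {
    # regression tree: one value per terminal node
    return(setNames(Predictions, Terminals))
  }
}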
Luckily, you took your trees from the examples in ?ctree, so we can test them (next time, please provide the code you used yourself)
Regression tree (your first tree)
## load the package and create the tree
library(party)
airq <- subset(airquality, !is.na(Ozone))
airct <- ctree(Ozone ~ ., data = airq,
               controls = ctree_control(maxsurrogate = 3))
plot(airct)
Now, test the function
res <- GetPredicts(airct)
res
#        5        3        6        9        8
# 18.47917 55.60000 31.14286 48.71429 81.63333
So we've got the prediction for each terminal node, named by node ID. You can easily proceed with which.max(res) from here (I'll leave it to you to decide)
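As a minimal sketch of that last step, the snippet below hard-codes the values printed above (so it runs without the party package) and picks the terminal node with the largest predicted mean. The names best and best_node are only illustrative.

```r
# Predictions copied from the output above, named by terminal node ID
res <- c("5" = 18.47917, "3" = 55.60000, "6" = 31.14286,
         "9" = 48.71429, "8" = 81.63333)

# which.max() returns the position of the largest value and keeps its name
best <- which.max(res)

# ID of the terminal node with the highest predicted mean
best_node <- names(best)
best_node
# [1] "8"

# The predicted mean in that node
res[[best]]
# [1] 81.63333
```

With the real tree you would use res <- GetPredicts(airct) instead of the hard-coded vector.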
Classification tree (your second tree)
irisct <- ctree(Species ~ ., data = iris)
plot(irisct, type = "simple")
Run the function
res <- GetPredicts(irisct)
res
#      2          5   6          7
# [1,] 1 0.00000000 0.0 0.00000000
# [2,] 0 0.97826087 0.5 0.02173913
# [3,] 0 0.02173913 0.5 0.97826087
Now the output is a bit harder to read, because each class has its own probability in every terminal node (rows are classes, columns are nodes). You could make this more readable using
row.names(res) <- levels(iris$Species)
res
#            2          5   6          7
# setosa     1 0.00000000 0.0 0.00000000
# versicolor 0 0.97826087 0.5 0.02173913
# virginica  0 0.02173913 0.5 0.97826087
Then you could do something like the following in order to get the overall maximum value
which(res == max(res), arr.ind = TRUE)
#        row col
# setosa   1   1
For column/row maxes, you could do
matrixStats::colMaxs(res)
# [1] 1.0000000 0.9782609 0.5000000 0.9782609
matrixStats::rowMaxs(res)
# [1] 1.0000000 0.9782609 0.9782609
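If you'd rather not depend on matrixStats, here is a rough base R sketch of the same idea that also reports which class attains the maximum in each terminal node. The matrix is hard-coded from the output above so the snippet runs on its own, and the names node_max and node_class are just illustrative. Keep in mind that which.max() returns the first maximum, so the tie in node 6 is reported as versicolor.

```r
# Class probabilities copied from the output above
# (rows are classes, columns are terminal node IDs; matrix() fills column-wise)
res <- matrix(
  c(1, 0,          0,
    0, 0.97826087, 0.02173913,
    0, 0.5,        0.5,
    0, 0.02173913, 0.97826087),
  nrow = 3,
  dimnames = list(c("setosa", "versicolor", "virginica"),
                  c("2", "5", "6", "7"))
)

# Highest class probability in each terminal node
node_max <- apply(res, 2, max)

# Class attaining that probability (first one in case of a tie)
node_class <- rownames(res)[apply(res, 2, which.max)]

data.frame(node = colnames(res), class = node_class, prob = node_max,
           row.names = NULL)
```

With the real tree you would start from res <- GetPredicts(irisct) followed by row.names(res) <- levels(iris$Species).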
But, again, I'll leave it to you to decide how to proceed from here.