I've been looking at the internal data structure of DecisionTreeClassifier in scikit-learn.
Specifically, I found this page, https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html , which is helpful because I need to extract the internal data of a trained decision tree.
One question came up, though. Each node has a `threshold` value and a `feature` value.
The threshold is clear to me: at test time, a feature vector from the test data is passed into the tree, and at each node one of its features is compared against that node's threshold.
But what exactly is the `feature` value stored in the trained tree (that is, learned from the training data)? Here is the code snippet:
import numpy as np
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree
iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = DecisionTreeClassifier(max_leaf_nodes=3, random_state=0)
clf.fit(X_train, y_train)
n_nodes = clf.tree_.node_count
children_left = clf.tree_.children_left
children_right = clf.tree_.children_right
# This is an array where one feature value
# is associated with each node in the trained tree.
# What is the meaning of the feature value for each node
# in the trained tree?
feature = clf.tree_.feature
threshold = clf.tree_.threshold
node_depth = np.zeros(shape=n_nodes, dtype=np.int64)
is_leaves = np.zeros(shape=n_nodes, dtype=bool)
# This prints `[ 3 -2 2 -2 -2]`,
# where nodes 1, 3, and 4 (0-indexed) are leaves
# and are associated with -2.
# What do 3 and 2 on the other (split) nodes mean?
# How were these values determined?
print(feature)
The feature vector has 4 dimensions in this case, and the tree has 5 nodes, counting both leaf and non-leaf nodes.
The `feature` array is [ 3 -2 2 -2 -2], where every node except the 0th and the 2nd is a leaf. The two non-leaf nodes are associated with the values 3 and 2. What do these values mean?
Does this mean that, for a feature vector x=(x0, x1, x2, x3) from the test data, x3 is compared with the threshold at node 0, while x2 is compared with the threshold at node 2?
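One way to test this reading is to walk the tree by hand and compare the leaf reached with what `clf.apply` reports. The following is only a sketch under stated assumptions: `manual_leaf` is a hypothetical helper (not part of scikit-learn), a node is treated as a leaf when its two child indices are equal, and samples with `x[feature] <= threshold` are assumed to go to the left child.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, random_state=0
)
clf = DecisionTreeClassifier(max_leaf_nodes=3, random_state=0)
clf.fit(X_train, y_train)

t = clf.tree_


def manual_leaf(x):
    """Hypothetical helper: follow the splits by hand, return the leaf index."""
    node = 0
    # Assumption: a leaf has children_left == children_right.
    while t.children_left[node] != t.children_right[node]:
        # Assumption: feature[node] is the column index tested at this node.
        if x[t.feature[node]] <= t.threshold[node]:
            node = t.children_left[node]
        else:
            node = t.children_right[node]
    return node


# scikit-learn converts inputs to float32 internally, so do the same here.
X32 = X_test.astype(np.float32)
manual = np.array([manual_leaf(x) for x in X32])

# Show which column each split node tests, by name.
for node in range(t.node_count):
    if t.children_left[node] != t.children_right[node]:
        print(node, iris.feature_names[t.feature[node]], t.threshold[node])

# True if the hand-written traversal reaches the same leaves as scikit-learn.
print(np.array_equal(manual, clf.apply(X_test)))
```

If the final line prints `True`, the values in `feature` behave as column indices into the input vector, which would support the interpretation above.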