
I am trying to have Graphviz display my oneHotEncoded categorical data but I can't get it to work.

Here is my X data with these columns:

Category, Size, Type, Rating, Genre, Number of versions   

['ART_AND_DESIGN' '6000000+' 'Free' 'Everyone' 'Art & Design' '7']  
['ART_AND_DESIGN' '6000000+' 'Free' 'Everyone' 'Art & Design' '2']  

...   
['FAMILY' '20000000+' 'Free' 'Everyone' 'Art & Design' '13']

And my code sample:

X = self.df.drop(['Installs'], axis=1).values
y = self.df['Installs'].values

self.oheFeatures = OneHotEncoder(categorical_features='all')
EncodedX = self.oheFeatures.fit_transform(X).toarray()

self.oheY = OneHotEncoder()
EncodedY = self.oheY.fit_transform(y.reshape(-1,1)).toarray()

self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(EncodedX, EncodedY, test_size=0.25, random_state=33)

clf = DecisionTreeClassifier(criterion='entropy', min_samples_leaf=100)
clf.fit(self.X_train, self.y_train)
    
dot_data = tree.export_graphviz(clf, out_file=None, 
            feature_names=self.oheFeatures.get_feature_names(),
            class_names=self.oheY.get_feature_names(),
            filled=True, 
            rounded=True,  
            special_characters=True)  

graph = graphviz.Source(dot_data) 
graph.render("applications") 

But when I try to visualize the output result, I get the decision tree of the encoded data:

[screenshot: Graphviz output of the tree built on the encoded data]

Is there any way to have graphviz display the "decoded" data instead?

desertnaut
PM Laforest

2 Answers


You seem to expect that there is some (relatively simple) programming recipe or workaround here, but that is far from the case: the issue actually goes far deeper than a simple Graphviz visualization.

What you have to keep in mind is:

  • Your tree does not "know" anything about any decoded data; as far as it is concerned, the only actual data are the one-hot encoded ones
  • Graphviz does nothing more than simply displaying the tree structure

Given the above, you may be able to see that, although your request sounds indeed meaningful from a "business" perspective serving tree interpretation (after all, one of the great advantages of tree models is supposed to be exactly their interpretability), the issue is highly non-trivial from a programming perspective (which SO is actually all about). It does not even have anything to do with Graphviz in particular - the same issue arises if we want to print the tree in the form of rules - and it has everything to do with the design choices made by scikit-learn for tree building.

I'll confess that, when I first heard about it, I was myself surprised to learn that decision trees in scikit-learn cannot directly handle categorical variables (see the discussion in Can sklearn DecisionTreeClassifier truly work with categorical data? and a still-open GitHub issue on the subject). And the reason I was surprised is that I knew this is certainly not the only design choice available: I come from an R background, and in R it is perfectly possible to fit trees with categorical features without the need for one-hot encoding; here is an example:

[screenshot: an R decision tree fitted directly on categorical features]

But alas, this is not available in scikit-learn, at least for the time being (and I highly doubt that it will change)...

desertnaut

@desertnaut is correct that there isn't a quick and easy way to do this, because the model inside sklearn treats the binary dummy variables exactly the same as any other real-valued feature.

But, at least in your simple case (where all the features are one-hot encoded), this isn't too hard to make work. First, you can pass the original column names as input_features to get_feature_names, so that the feature names in the plot are more useful than X[34]. Next, the output of export_graphviz is DOT code, which is human-readable and therefore human-editable. In a very small example like yours, you could do it entirely by hand; in larger examples, you might want to make use of regex replacements or something similar.

I put together a notebook to demonstrate this, once in the very simple case and once in a somewhat more-complicated case. I thought about monkey-patching parts of the export_graphviz methods, but ended up instead just modifying the DOT code after the fact. When sklearn finishes deciding on and implementing an approach to preserving feature names (or at least finishes fleshing out the get_feature_names methods for all transformers), the second example should work for significantly more complex transformer pipelines/composites.
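For the second step, a minimal sketch of the regex idea (the DOT string below is a stand-in for real export_graphviz output, and the rewrite rule assumes dummy names of the form feature_category; category values that themselves contain underscores would need a more careful pattern):

```python
# Sketch: export_graphviz emits plain DOT text, so a regex pass can
# rewrite dummy-variable splits like 'Category_FAMILY <= 0.5' into a
# decoded form such as 'Category != FAMILY'.
import re

# Stand-in for a node label from real export_graphviz output.
dot = 'label="Category_FAMILY <= 0.5\\ngini = 0.5"'

# '<feature>_<category> <= 0.5' sends rows NOT in that category left,
# so the left-branch condition decodes to 'feature != category'.
decoded = re.sub(r'(\w+?)_(\w+) <= 0\.5', r'\1 != \2', dot)
print(decoded)  # label="Category != FAMILY\ngini = 0.5"
```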

Ben Reiniger