
It is possible to visualize decision trees using pydotplus from PyPI, but it has issues on my machine: it says it was not built with libexpat, so it only shows a number on each node instead of a table with some information. I'd like to use an alternative. I already tried networkx, but it requires pygraphviz to read .dot files and turn them into a networkx graph, and installing pygraphviz with pip also failed.

So now I am looking for an alternative way of visualizing decision trees that can be installed using pip or Anaconda.

Which alternatives exist?

EDIT#1

Output of conda list:

# packages in environment at /home/xiaolong/development/anaconda3/envs/coursera_ml_classification:
#
alabaster                 0.7.7                    py34_0    defaults
awscli                    1.6.2                     <pip>
babel                     2.3.3                    py34_0    defaults
backports                 1.0                      py34_0    defaults
backports-abc             0.4                       <pip>
backports.shutil-get-terminal-size 1.0.0                     <pip>
backports_abc             0.4                      py34_0    defaults
bcdoc                     0.12.2                    <pip>
boto                      2.33.0                    <pip>
botocore                  0.73.0                    <pip>
cairo                     1.12.18                       6    defaults
certifi                   2015.4.28                 <pip>
colorama                  0.2.5                     <pip>
cycler                    0.10.0                   py34_0    defaults
decorator                 4.0.9                    py34_0    defaults
docutils                  0.12                     py34_0    defaults
entrypoints               0.2                      py34_1    defaults
expat                     2.1.0                         0    defaults
fontconfig                2.11.1                        5    defaults
freetype                  2.5.5                         0    defaults
get_terminal_size         1.0.0                    py34_0    defaults
glib                      2.43.0                        2    asmeurer
graphviz                  2.38.0                        1    defaults
harfbuzz                  0.9.39                        0    defaults
imagesize                 0.7.0                    py34_0    defaults
ipykernel                 4.3.1                    py34_0    defaults
ipython                   4.2.0                    py34_0    defaults
ipython-genutils          0.1.0                     <pip>
ipython_genutils          0.1.0                    py34_0    defaults
ipywidgets                4.1.1                    py34_0    defaults
jedi                      0.9.0                    py34_0    defaults
jinja2                    2.8                      py34_0    defaults
jmespath                  0.5.0                     <pip>
jsonschema                2.5.1                    py34_0    defaults
jupyter                   1.0.0                    py34_2    defaults
jupyter-client            4.2.2                     <pip>
jupyter-console           4.1.1                     <pip>
jupyter-core              4.1.0                     <pip>
jupyter_client            4.2.2                    py34_0    defaults
jupyter_console           4.1.1                    py34_0    defaults
jupyter_core              4.1.0                    py34_0    defaults
libffi                    3.2.1                         0    defaults
libgcc                    5.2.0                         0    defaults
libgfortran               3.0.0                         1    defaults
libpng                    1.6.17                        0    defaults
libsodium                 1.0.3                         0    defaults
libxml2                   2.9.2                         0    defaults
llvmlite                  0.10.0                   py34_0    defaults
markupsafe                0.23                     py34_0    defaults
matplotlib                1.5.1               np111py34_0    defaults
mistune                   0.7.2                    py34_0    defaults
mkl                       11.3.1                        0    defaults
multipledispatch          0.4.8                     <pip>
nbconvert                 4.2.0                    py34_0    defaults
nbformat                  4.0.1                    py34_0    defaults
notebook                  4.2.0                    py34_0    defaults
numpy                     1.11.0                   py34_0    defaults
openssl                   1.0.2h                        0    defaults
pandas                    0.18.1              np111py34_0    defaults
pango                     1.39.0                        0    defaults
path.py                   8.2.1                    py34_0    defaults
pep8                      1.7.0                    py34_0    defaults
pexpect                   4.0.1                    py34_0    defaults
pickleshare               0.5                      py34_0    defaults
pip                       8.1.1                    py34_1    defaults
pixman                    0.32.6                        0    defaults
prettytable               0.7.2                     <pip>
psutil                    4.1.0                    py34_0    defaults
ptyprocess                0.5                      py34_0    defaults
pyasn1                    0.1.9                     <pip>
pydotplus                 2.0.2                    py34_0    file:///home/xiaolong/development/anaconda3/conda-bld/linux-64/pydotplus-2.0.2-py34_0.tar.bz2
pyflakes                  1.1.0                    py34_0    defaults
pygments                  2.1.3                    py34_0    defaults
pyparsing                 2.1.1                    py34_0    defaults
pyqt                      4.11.4                   py34_1    defaults
python                    3.4.4                         0    defaults
python-contrib-nbextensions alpha                     <pip>
python-dateutil           2.5.2                    py34_0    defaults
pytz                      2016.3                   py34_0    defaults
pyyaml                    3.11                      <pip>
pyzmq                     15.2.0                   py34_0    defaults
qt                        4.8.7                         1    defaults
qtconsole                 4.2.1                    py34_0    defaults
readline                  6.2                           2    defaults
requests                  2.9.1                     <pip>
rope                      0.9.4                    py34_1    defaults
rope-py3k                 0.9.4.post1               <pip>
rsa                       3.1.2                     <pip>
scikit-learn              0.17.1              np111py34_0    defaults
scipy                     0.17.0              np111py34_3    defaults
setuptools                20.7.0                   py34_0    defaults
sframe                    1.8.5                     <pip>
simplegeneric             0.8.1                    py34_0    defaults
sip                       4.16.9                   py34_0    defaults
six                       1.10.0                   py34_0    defaults
snowballstemmer           1.2.1                    py34_0    defaults
sphinx                    1.4.1                    py34_0    defaults
sphinx-rtd-theme          0.1.9                     <pip>
sphinx_rtd_theme          0.1.9                    py34_0    defaults
spyder                    2.3.8                    py34_1    defaults
sqlite                    3.9.2                         0    defaults
terminado                 0.5                      py34_1    defaults
tk                        8.5.18                        0    defaults
tornado                   4.3                      py34_0    defaults
traitlets                 4.2.1                    py34_0    defaults
wheel                     0.29.0                   py34_0    defaults
xz                        5.0.5                         1    defaults
zeromq                    4.1.3                         0    defaults
zlib                      1.2.8                         0    defaults

SciPy version: 0.17.0

digraph Tree {
node [shape=box, style="filled", color="black"] ;
0 [label="grade.B <= 0.5\ngini = 0.5\nsamples = 37224\nvalue = [18476, 18748]", fillcolor="#399de504"] ;
1 [label="grade.C <= 0.5\ngini = 0.4973\nsamples = 32094\nvalue = [17218, 14876]", fillcolor="#e5813923"] ;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
2 [label="gini = 0.4829\nsamples = 21728\nvalue = [12875, 8853]", fillcolor="#e5813950"] ;
1 -> 2 ;
3 [label="gini = 0.4869\nsamples = 10366\nvalue = [4343, 6023]", fillcolor="#399de547"] ;
1 -> 3 ;
4 [label="grade.A <= 14.8301\ngini = 0.3702\nsamples = 5130\nvalue = [1258, 3872]", fillcolor="#399de5ac"] ;
0 -> 4 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
5 [label="gini = 0.3555\nsamples = 4987\nvalue = [1153, 3834]", fillcolor="#399de5b2"] ;
4 -> 5 ;
6 [label="gini = 0.3902\nsamples = 143\nvalue = [105, 38]", fillcolor="#e58139a3"] ;
4 -> 6 ;
}

EDIT#2

I programmed this in a Jupyter notebook, which has a bug: the SVG is shown without its fill colors if you display it using:

![Decision Tree]('dtree.svg')

I found a work-around here:

from IPython.display import HTML

# Read the SVG markup and hand it to the notebook as raw HTML
with open('dtree.svg') as svg_file:
    svg = svg_file.read()

HTML(svg)
Terence Parr
Zelphir Kaltstahl

2 Answers


It's not the sexiest solution, but I use the Graphviz CLI (the tool is called dot), called via subprocess. I'm on a Mac, so I installed it with Homebrew, but you can download binaries for other platforms from the Graphviz downloads page. Here's an example using the Titanic dataset:

import subprocess

import pandas as pd
import seaborn as sns  # older seaborn versions also offered seaborn.apionly, since removed
from sklearn.impute import SimpleImputer  # replaces sklearn.preprocessing.Imputer (removed in 0.22)
from sklearn.tree import DecisionTreeClassifier, export_graphviz

raw_data = sns.load_dataset('titanic')  # downloads the dataset on first use
predictors = ['pclass', 'sex', 'age', 'sibsp', 'parch', 'fare', 'embarked', 'alone', 'adult_male']
categorical = ['sex', 'embarked']
target = 'survived'

# One-hot encode the categorical columns and cast everything (including bools) to float
encoded_data = pd.get_dummies(raw_data[predictors], columns=categorical).astype('float32')

# Fill missing values (e.g. age) with the column mean
imputer = SimpleImputer()
X = imputer.fit_transform(encoded_data)
Y = raw_data[target].astype('float32')

model = DecisionTreeClassifier(min_samples_leaf=10, max_depth=3)
model.fit(X, Y)

export_graphviz(model,
                out_file='tree.dot',
                feature_names=list(encoded_data.columns),
                proportion=True,
                filled=True,   # writes fillcolor attributes, i.e. colored nodes
                impurity=False)

# Render with the Graphviz CLI; the dot binary must be on PATH
subprocess.call(['dot', '-Tpdf', 'tree.dot', '-o', 'tree.pdf'])
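A variation on the same approach, sketched under the assumption that Python 3.5+ (for subprocess.run) and the dot binary on PATH are available: passing out_file=None makes export_graphviz return the DOT source as a string, which can be piped straight into dot without temporary files. The resulting SVG text can then be shown in a notebook with the HTML work-around from the question. The helper names tree_to_dot and dot_to_svg are hypothetical.

```python
import subprocess

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz


def tree_to_dot(model, feature_names, class_names=None):
    """Return the DOT source as a string (out_file=None makes export_graphviz return it)."""
    return export_graphviz(model,
                           out_file=None,        # return a str instead of writing a file
                           feature_names=feature_names,
                           class_names=class_names,
                           filled=True,          # without this the DOT source has no fillcolor
                           impurity=False,
                           proportion=True)


def dot_to_svg(dot_source):
    """Pipe DOT source through the Graphviz `dot` CLI (must be on PATH) and return SVG text."""
    result = subprocess.run(['dot', '-Tsvg'],
                            input=dot_source.encode('utf-8'),
                            stdout=subprocess.PIPE,
                            check=True)          # raise if dot reports an error
    return result.stdout.decode('utf-8')


# Example usage (needs the dot binary):
# iris = load_iris()
# model = DecisionTreeClassifier(max_depth=3).fit(iris.data, iris.target)
# svg = dot_to_svg(tree_to_dot(model, iris.feature_names, list(iris.target_names)))
# from IPython.display import HTML; HTML(svg)
```

Checking the returned string for fillcolor is a quick way to confirm that filled=True actually reached the exporter before blaming the renderer.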
maxymoo
  • ~~Tried and working. However, the _coloring is missing_ from the graph. Is there an easy fix for that?~~ Nvm, the coloring is coming from `pydotplus`, which in turn cannot render the labels correctly, so for me it's either one or the other. – Zelphir Kaltstahl May 12 '16 at 11:58
  • i'm getting colors on mine, what platform are you on? – maxymoo May 12 '16 at 22:54
  • I am on Linux, but the thing is that `pydotplus` also gives me a warning that it has not been built with `libexpat`, and I didn't find a way of fixing that while still installing it in a `virtualenv` using `pip` or, better, `anaconda`. So without libexpat it cannot display tabular data as node labels. However, if I use the `dot` tool itself as suggested in the answer, there seems to be no color information in my dot files, so it seems only logical that there is none in any graphic generated from them. Did I go wrong somewhere? Colors _would_ be nice after all : ) – Zelphir Kaltstahl May 13 '16 at 02:00
  • are you using the latest version of sklearn? there's color info in my dot file, the first node is `[label="adult_male <= 0.5\nsamples = 100.0%\nvalue = [0.62, 0.38]", fillcolor="#e5813960"] ;` – maxymoo May 13 '16 at 02:05
  • I revised my code again and what do you know, I forgot to add the `filled=True`. However, even with the `fillcolor` values in the dot file and the colors in the output SVG I create, the notebook still shows a black and white SVG. Maybe it's a browser issue with colored SVG or maybe a Jupyter notebook issue. – Zelphir Kaltstahl May 13 '16 at 14:40
  • Found a work-around for the notebook issue. Thanks for stating that you have color on your end, I might not have checked this again! Now I have exactly the visualization I wanted. – Zelphir Kaltstahl May 13 '16 at 15:01
  • Glad you've got it working. How are you displaying the SVG in the notebook? Did you have to manually link it? – maxymoo May 18 '16 at 23:16
  • Check out the OP, or this: https://github.com/scikit-learn/scikit-learn/issues/6522 : ) – Zelphir Kaltstahl May 19 '16 at 13:52

Since version 0.21, scikit-learn has a plot_tree function that plots the tree with matplotlib.

The code to use plot_tree:

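import matplotlib.pyplot as plt
from sklearn import tree
from sklearn.datasets import load_iris

iris = load_iris()
clf = tree.DecisionTreeClassifier(max_depth=3).fit(iris.data, iris.target)  # the clf is the Decision Tree object
tree.plot_tree(clf, feature_names=iris.feature_names,
               class_names=list(iris.target_names),
               filled=True)
plt.show()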
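When no display is available (for example on a server), the same plot can be written straight to an image file. A minimal sketch, assuming scikit-learn >= 0.21 and matplotlib; the file name tree.png, the figure size and the depth are arbitrary choices:

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend: no window needed, only file output
import matplotlib.pyplot as plt
from sklearn import tree
from sklearn.datasets import load_iris

iris = load_iris()
clf = tree.DecisionTreeClassifier(max_depth=3, random_state=0).fit(iris.data, iris.target)

# A large figure keeps the node boxes readable; increase figsize for deeper trees
fig, ax = plt.subplots(figsize=(12, 8))
annotations = tree.plot_tree(clf,
                             feature_names=iris.feature_names,
                             class_names=list(iris.target_names),
                             filled=True,
                             ax=ax)  # draw into this specific Axes
fig.savefig('tree.png', dpi=100, bbox_inches='tight')
plt.close(fig)
```

Because plot_tree accepts an ax argument, the tree can be drawn into any existing figure; savefig picks PNG, SVG or PDF from the file extension.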

An alternative to the scikit-learn plots is the dtreeviz package. The code to use dtreeviz:

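from dtreeviz.trees import dtreeviz  # dtreeviz 1.x API; newer 2.x releases use dtreeviz.model(...) instead
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X, y = iris.data, iris.target
clf = DecisionTreeClassifier(max_depth=3).fit(X, y)  # the clf is the Decision Tree object

viz = dtreeviz(clf, X, y,
               target_name="target",
               feature_names=iris.feature_names,
               class_names=list(iris.target_names))

viz  # in a Jupyter notebook, evaluating viz renders the tree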

You can find a comparison of different scikit-learn tree plotting techniques here.

[Image: dtreeviz decision tree visualization]
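If a text view is enough, the same scikit-learn 0.21 release also added export_text, which needs nothing beyond scikit-learn itself (no Graphviz, no libexpat). A minimal sketch on the iris data; the depth and random_state are arbitrary:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(iris.data, iris.target)

# export_text returns the decision rules as one plain string, one node per line
report = export_text(clf, feature_names=list(iris.feature_names))
print(report)
```

Each split is printed as a "|--- feature <= threshold" line and each leaf as a "class: ..." line, indented by depth, which is handy for logs or terminals where no image can be shown.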

pplonski