
I have a function that displays my classification tree in an ipywidgets widget:

import numpy as np
from graphviz import Source
from IPython.display import SVG, display
from ipywidgets import interactive
from sklearn import tree
from sklearn.tree import DecisionTreeClassifier

# X (a pandas DataFrame of features) and y (the target) are defined earlier
labels = X.columns

# create and fit the decision tree, then draw it
def plot_tree(criterion, depth, split, min_split, min_leaf=0.2):
    # create and fit the tree
    clf = DecisionTreeClassifier(criterion=criterion, max_depth=depth, splitter=split,
                                 min_samples_split=min_split, min_samples_leaf=min_leaf)
    clf.fit(X, y)

    # build the graph and display it as SVG
    graph = Source(tree.export_graphviz(clf, out_file=None, feature_names=labels, filled=True))
    display(SVG(graph.pipe(format='svg')))

    return clf

# create the interactive controls for the graph
inter = interactive(plot_tree, criterion=["gini", "entropy"], depth=np.arange(1, 12),
                    split=["best", "random"], min_split=(0.1, 1), min_leaf=(0.1, 0.5))

display(inter)

and this code to save the tree as a PNG:

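from IPython.display import Image

# `graph` must be a graphviz.Source object available in this scope
png_bytes = graph.pipe(format='png')
with open('../graph/tree.png', 'wb') as f:
    f.write(png_bytes)
Image(png_bytes)  # show the saved PNG as the cell's output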

The problem is that if I put the saving code in a function right after 'def plot_tree', the saved PNG is the first tree displayed, with a depth of only 1. If I put the saving code in a separate cell instead, the whole tree is saved.

I want to add a button to the widget that saves the currently displayed tree, using whatever parameters are set at that moment. How can I do that?

lela_rib
