
I am trying to generate a visualisation of a decision tree. However, I am getting an error that I cannot resolve. This is my code:

from sklearn.tree import export_graphviz
from sklearn.externals.six import StringIO
from IPython.display import Image
import pydotplus

feature_cols = ['Reason_for_absence', 'Month_of_absence']
feature_cols
dot_data = StringIO()
export_graphviz(clf, out_file=dot_data, filled=True, rounded=True, special_characters=True, feature_names = feature_cols,class_names['0', '1'])
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
graph.write_png('tree.png')
Image(graph.create_png())

I am getting the following error:

File "", line 9
    export_graphviz(clf, out_file=dot_data, filled=True, rounded=True, special_characters=True, feature_names = feature_cols,class_names['0', '1'])
                                                                                                                            ^
SyntaxError: positional argument follows keyword argument
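This error is raised before any code runs. Without the `=`, `class_names['0', '1']` parses as a subscript expression, which makes it a positional argument placed after keyword arguments, and Python's grammar forbids that. A small sketch using the built-in `compile()` shows the difference; `f`, `a` and `x` are placeholder names that are only compiled, never evaluated.

```python
# Python rejects a positional argument after keyword arguments at compile time,
# so nothing has to be imported or executed to see the error.
bad = "f(a, out_file=x, class_names['0', '1'])"   # missing '=' -> subscript expression
good = "f(a, out_file=x, class_names=['0', '1'])"  # keyword argument, as intended


def compiles(src):
    """Return True if src compiles as an expression, False on SyntaxError."""
    try:
        compile(src, "<example>", "eval")
        return True
    except SyntaxError as exc:
        print("SyntaxError:", exc.msg)
        return False


print(compiles(bad))
print(compiles(good))
```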

EDIT:

I have changed the code according to the answer, and now I am getting this error:

IndexError: list index out of range

The code was amended a bit:

feature_cols = ['Reason_for_absence',
 'Month_of_absence',
 'Day_of_the_week',
 'Seasons',
 'Transportation_expense',
 'Distance_from_Residence_to_Work',
 'Service_time',
 'Age',
 'Work_load_Average/day ',
 'Hit_target',
 'Disciplinary_failure',
 'Education',
 'Son',
 'Social_drinker',
 'Social_smoker',
 'Pet',
 'Weight',
 'Height',
 'Bod_mass_index',
 'Absenteeism']
dot_data = StringIO()

export_graphviz(clf, out_file=dot_data, filled=True, rounded=True, special_characters=True, feature_names = feature_cols, class_names=['0', '1'])
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
graph.write_png('tree.png')
Image(graph.create_png())
Matt

1 Answer


You were missing an `=`; update the last argument to `class_names=['0', '1']`:

export_graphviz(clf, out_file=dot_data, filled=True, rounded=True, 
    special_characters=True, 
    feature_names = feature_cols, 
    class_names=['0', '1'])
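A self-contained sketch of the corrected call, under stated assumptions: the question's `clf` and data are not shown, so a small synthetic binary dataset and the hypothetical feature names `f0`, `f1` stand in for them. It imports `StringIO` from the standard `io` module because `sklearn.externals.six` was deprecated in scikit-learn 0.21 and removed in later releases. It also shows that passing `out_file=None` makes `export_graphviz` return the DOT source as a string. Rendering to PNG with pydotplus additionally needs the Graphviz binaries installed, so that step is left out here.

```python
from io import StringIO  # stand-in for the removed sklearn.externals.six.StringIO

import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_graphviz

# Hypothetical stand-in data: two features, binary target (0/1).
X = np.array([[0, 1], [1, 1], [2, 0], [3, 0], [4, 1], [5, 0]])
y = np.array([0, 0, 0, 1, 1, 1])
clf = DecisionTreeClassifier(random_state=0).fit(X, y)

# Every argument after the first keyword argument is itself a keyword argument.
dot_data = StringIO()
export_graphviz(clf, out_file=dot_data, filled=True, rounded=True,
                special_characters=True,
                feature_names=['f0', 'f1'],
                class_names=['0', '1'])

# Alternatively, out_file=None returns the DOT source directly as a string.
dot_source = export_graphviz(clf, out_file=None, filled=True, rounded=True,
                             special_characters=True,
                             feature_names=['f0', 'f1'],
                             class_names=['0', '1'])
print(dot_source.splitlines()[0])
```

The string can then be passed to `pydotplus.graph_from_dot_data` exactly as in the question.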
Paul Lo
  • Now, I am getting list index out of range. – Matt Dec 11 '19 at 02:01
  • @HalfMartianHalfHuman That might have something to do with your `y`: if there are more than two outcome classes (multi-class rather than binary classification), you need to provide the same number of `class_names`. – Paul Lo Dec 11 '19 at 02:09
  • Actually, I removed `class_names=['0', '1']` from the `export_graphviz` call (after `feature_names`) and the visualisation was generated. – Matt Dec 11 '19 at 02:15
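A minimal sketch of Paul Lo's suggestion, assuming the question's target has more than two classes (the data itself is not shown). It builds `class_names` from the fitted model's `classes_` attribute so the list always has one entry per class. It also checks `feature_names` against `n_features_in_` (available in scikit-learn 0.24 and later). The three-class data and the `f0`/`f1` names are hypothetical stand-ins.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_graphviz

# Hypothetical stand-in data: two features, three-class target.
X = np.array([[0, 1], [1, 1], [2, 0], [3, 0], [4, 1], [5, 0]])
y = np.array([0, 0, 1, 1, 2, 2])
clf = DecisionTreeClassifier(random_state=0).fit(X, y)

feature_cols = ['f0', 'f1']

# One name per feature the model was trained on.
if len(feature_cols) != clf.n_features_in_:
    raise ValueError(
        f"expected {clf.n_features_in_} feature names, got {len(feature_cols)}")

# One name per class, in the order of clf.classes_; a hard-coded ['0', '1']
# would be too short for this three-class target.
class_names = [str(c) for c in clf.classes_]

dot_source = export_graphviz(clf, out_file=None, filled=True, rounded=True,
                             special_characters=True,
                             feature_names=feature_cols,
                             class_names=class_names)
print(class_names)
```

Deriving the names from the model, rather than hard-coding them, keeps the call correct if the set of outcome classes changes.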