I am new to scikit-learn and to machine learning with Python. I was trying to work with a decision tree. I managed to do all the data cleaning, analysis and so on, until I tried to get the decision tree diagram.

I am working with Python 3.4 and pyplot2. I have a function called decision_tree that builds the model and then calls a function (plot_classifier), passing it the classifier (clf), to draw the tree with these lines:

    dot_data = StringIO()
    export_graphviz(clf, out_file=dot_data)
    print(type(dot_data.getvalue()))
    graph = pydot.graph_from_dot_data(dot_data.getvalue())
    Image(graph.create_png())

This code is like the scikit-learn examples. The problem is in the line that calls pydot.graph_from_dot_data. I have fit my model and checked the results. They were OK, but I can't figure out how to draw the tree. I get this on the console:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-66-900e97a70715> in <module>()
----> 1 decision_tree(feature_train, feature_test, label_train, label_test)

<ipython-input-3-c2e245caf6c8> in decision_tree(feature_train, feature_test,     label_train, label_test)
 17     print("Cantidad de aciertos: " + str(count) + "\n Cantidad de Elementos: " + str(len(pred)))
 18     print(no_match)
---> 19     plot_classifier(clf)
20     return accuracy
21 

<ipython-input-65-78f8ba2dc1ed> in plot_classifier(clf)
  3         export_graphviz(clf, out_file=dot_data)
  4         print(type(dot_data.getvalue()))
----> 5         graph = pydot.graph_from_dot_data(dot_data.getvalue())
  6         Image(graph.create_png())
  7 

C:\Anaconda3\lib\site-packages\pydot.py in graph_from_dot_data(data)
218     """
219 
--> 220     return dot_parser.parse_dot_data(data)
221 
222 

C:\Anaconda3\lib\site-packages\dot_parser.py in parse_dot_data(data)
508     top_graphs = list()
509 
--> 510     if data.startswith(codecs.BOM_UTF8):
511         data = data.decode( 'utf-8' )
512 

TypeError: startswith first arg must be str or a tuple of str, not bytes

I searched the internet for the error message, but the answers talk about the startswith line (that is in the library, and I don't think the problem is there when a lot of people have this working). I checked the other lines for problems and can't find the cause there either.

Can anyone help me with this? I tried things like converting to a string or a tuple (even though getvalue() already returns a string), but nothing worked.

1 Answer

I had the same problem on Linux and solved it by installing pydotplus and its dependencies.

https://pypi.python.org/pypi/pydotplus

After installing pydotplus and its dependencies:

    import pydotplus

    graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
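
For context, here is a minimal sketch of why the original call seems to fail on Python 3, based only on the traceback in the question: line 510 of dot_parser.py calls startswith with codecs.BOM_UTF8, which is a bytes object, while the data coming out of StringIO is a str. The dot_text value below is a made-up placeholder, not real export_graphviz output, and the snippet uses only the standard library.

```python
import codecs

# Placeholder standing in for dot_data.getvalue(); on Python 3 this is a str.
dot_text = "digraph Tree { }"

# codecs.BOM_UTF8 is bytes, so this check mixes str and bytes,
# mirroring the line flagged in the traceback.
try:
    dot_text.startswith(codecs.BOM_UTF8)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(error_message)

# Comparing like with like works: encode the text first, then test for the BOM.
has_bom = dot_text.encode("utf-8").startswith(codecs.BOM_UTF8)
print(has_bom)
```

So the string passed in is fine; the bytes/str mix is inside the old parser, which is why switching to a library that handles Python 3 strings avoids the error.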