
I'm building a decision tree using Scikit-Learn in Python. I've trained the model on a particular dataset, and now I want to save this decision tree so that it can be used later (on a new dataset). Does anyone know how to do this?

sat63k
nEO

3 Answers


As described in the Model Persistence section of the scikit-learn tutorial:

It is possible to save a model in the scikit by using Python’s built-in persistence model, namely pickle:

>>> from sklearn import svm
>>> from sklearn import datasets
>>> clf = svm.SVC()
>>> iris = datasets.load_iris()
>>> X, y = iris.data, iris.target
>>> clf.fit(X, y)  
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, degree=3, gamma=0.0,
  kernel='rbf', max_iter=-1, probability=False, random_state=None,
  shrinking=True, tol=0.001, verbose=False)

>>> import pickle
>>> s = pickle.dumps(clf)
>>> clf2 = pickle.loads(s)
>>> clf2.predict(X[0:1])  # predict() expects a 2D array: one row, not a 1D vector
array([0])
>>> y[0]
0
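Applied to the decision tree from the question, the same idea works with a file instead of an in-memory string. The following is a minimal sketch, assuming a `DecisionTreeClassifier` trained on iris stands in for your own model, the first rows of `X` stand in for the "new dataset", and `tree.pkl` is a hypothetical path:

```python
import pickle

import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

# Train a decision tree (iris is only a stand-in for your own dataset).
X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(random_state=0).fit(X, y)

# Save the fitted tree to disk ("tree.pkl" is a hypothetical path).
with open("tree.pkl", "wb") as f:
    pickle.dump(clf, f)

# Later, possibly in another Python process: load it back.
with open("tree.pkl", "rb") as f:
    clf_loaded = pickle.load(f)

# The whole fitted structure survives the round trip.
same_structure = export_text(clf) == export_text(clf_loaded)

# Predict on "new" data; predict() expects a 2D array, hence X[:5] rather than X[0].
new_data = X[:5]
same_predictions = np.array_equal(clf.predict(new_data), clf_loaded.predict(new_data))
print(same_structure, same_predictions)  # True True
```

Only unpickle files you trust: loading a pickle can execute arbitrary code.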
Matthew Spencer
  • That's cool. So when I save the DT into the pickle file, it stores the whole structure and it can be reused later. That's really nice. Is there anything similar for R? – nEO Oct 02 '14 at 13:24
  • Got it: in R, one can use the save command to write R objects to a file and then the load command to read them back. More details about the save command can be found [link](https://stat.ethz.ch/R-manual/R-devel/library/base/html/save.html), and about the load command at [link](https://stat.ethz.ch/R-manual/R-devel/library/base/html/load.html) – nEO Oct 02 '14 at 14:05

There is currently no reliable way of doing this. While pickling does work, it is not good enough, because a pickled model is not guaranteed to unpickle correctly with a later version of scikit-learn.

Quote from: http://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations

Models saved in one version of scikit-learn might not load in another version.
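One common mitigation (it does not remove the limitation) is to record the scikit-learn version next to the model and check it before unpickling. This is a sketch under stated assumptions: `save_model` and `load_model` are hypothetical helpers, and the file simply holds two consecutive pickles, the version string first and the model second, so the version can be read without touching any scikit-learn classes. Recent scikit-learn releases (1.3 and later, as far as I know) also emit an `InconsistentVersionWarning` on their own in this situation.

```python
import pickle
import warnings

import sklearn
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier


def save_model(model, path):
    """Hypothetical helper: write the scikit-learn version first, then the model."""
    with open(path, "wb") as f:
        pickle.dump(sklearn.__version__, f)
        pickle.dump(model, f)


def load_model(path):
    """Hypothetical helper: check the recorded version before unpickling the model."""
    with open(path, "rb") as f:
        saved_version = pickle.load(f)  # a plain string, independent of scikit-learn
        if saved_version != sklearn.__version__:
            warnings.warn(
                f"model was saved with scikit-learn {saved_version} but "
                f"{sklearn.__version__} is installed; it may fail to load or behave differently"
            )
        return pickle.load(f)  # second pickle in the file: the model itself


X, y = load_iris(return_X_y=True)
save_model(DecisionTreeClassifier(random_state=0).fit(X, y), "tree_versioned.pkl")
model = load_model("tree_versioned.pkl")
print(model.predict(X[:1]))  # [0]
```

If the versions differ, the safest route is usually to retrain the model under the installed version rather than rely on the old pickle.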

Bastian Venthur

I used joblib as below:

>>> from joblib import dump, load
>>> dump(clf, 'filename.joblib')
>>> clf = load('filename.joblib')

However, you need to consider the security and maintainability limitations of this approach, described in the model persistence section of the scikit-learn documentation.
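A self-contained version of the joblib approach for a decision tree might look like the sketch below. It assumes the standalone `joblib` package (which scikit-learn depends on) is installed; `tree.joblib` is a hypothetical path, and `compress=3` is an optional setting that trades some speed for a smaller file.

```python
import joblib
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

# Train a decision tree (iris is only a stand-in for your own dataset).
X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(random_state=0).fit(X, y)

# Save the fitted tree; compression is optional.
joblib.dump(clf, "tree.joblib", compress=3)

# Load it back later and use it as before.
clf_loaded = joblib.load("tree.joblib")
print(np.array_equal(clf.predict(X), clf_loaded.predict(X)))  # True
```

joblib uses pickle under the hood, so the same caveats apply: only load files from trusted sources, and expect possible breakage across scikit-learn versions.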

sat63k