I have a classifier that I trained using Python's scikit-learn. How can I use the classifier from a Java program? Can I use Jython? Is there some way to save the classifier in Python and load it in Java? Is there some other way to use it?
6 Answers
You cannot use Jython, as scikit-learn relies heavily on NumPy and SciPy, which have many compiled C and Fortran extensions and hence cannot run on Jython.
The easiest ways to use scikit-learn in a Java environment would be to:
- expose the classifier as an HTTP/JSON service, for instance using a microframework such as Flask, Bottle or Cornice, and call it from Java using an HTTP client library;
- write a command-line wrapper application in Python that reads data on stdin and outputs predictions on stdout using some format such as CSV or JSON (or some lower-level binary representation), and call the Python program from Java, for instance using Apache Commons Exec;
- make the Python program output the raw numerical parameters learnt at fit time (typically as an array of floating-point values) and reimplement the predict function in Java (this is typically easy for predictive linear models, where the prediction is often just a thresholded dot product).
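For the command-line wrapper option, here is a minimal sketch of the Java side. It uses the JDK's own ProcessBuilder instead of Apache Commons Exec to stay dependency-free, and it assumes a hypothetical wrapper script (for example `predict.py`) that reads one CSV row of features per line on stdin and prints one prediction per line on stdout.

```java
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

// Calls a hypothetical Python wrapper (e.g. "predict.py") that reads one CSV row of
// features per line on stdin and prints one prediction per line on stdout.
final class PythonPredictor {
    private final List<String> command;

    PythonPredictor(List<String> command) {
        this.command = List.copyOf(command);
    }

    // Formats one sample as a CSV row, e.g. {5.1, 3.5} -> "5.1,3.5".
    static String toCsvRow(double[] features) {
        return Arrays.stream(features)
                .mapToObj(d -> Double.toString(d))
                .collect(Collectors.joining(","));
    }

    List<String> predict(List<double[]> samples) throws IOException, InterruptedException {
        Process process = new ProcessBuilder(command)
                .redirectError(ProcessBuilder.Redirect.INHERIT) // surface Python tracebacks
                .start();
        // Send all rows, then close stdin so the script sees EOF. Assumes the script's
        // output is small; for very large batches, write from a separate thread so
        // neither side blocks on a full pipe buffer.
        try (BufferedWriter stdin = new BufferedWriter(
                new OutputStreamWriter(process.getOutputStream(), StandardCharsets.UTF_8))) {
            for (double[] sample : samples) {
                stdin.write(toCsvRow(sample));
                stdin.newLine();
            }
        }
        List<String> predictions = new ArrayList<>();
        try (BufferedReader stdout = new BufferedReader(
                new InputStreamReader(process.getInputStream(), StandardCharsets.UTF_8))) {
            String line;
            while ((line = stdout.readLine()) != null) {
                predictions.add(line.trim());
            }
        }
        int exitCode = process.waitFor();
        if (exitCode != 0) {
            throw new IOException("predictor exited with code " + exitCode);
        }
        return predictions;
    }
}
```

Usage would look like `new PythonPredictor(List.of("python3", "predict.py")).predict(List.of(new double[]{5.1, 3.5}))`. Starting one process per batch rather than per sample keeps the Python interpreter and model-loading overhead down.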
The last approach will be a lot more work if you need to re-implement feature extraction in Java as well.
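For the third option, a minimal sketch of the Java side for a binary linear classifier such as LogisticRegression or LinearSVC. It assumes you exported `clf.coef_[0]` and `clf.intercept_[0]` from Python (for example as JSON) and that the Java code applies exactly the same feature preprocessing (scaling, encoding) as training did. The class name and methods are illustrative, not part of any library.

```java
// Hypothetical Java re-implementation of a binary scikit-learn linear model.
final class LinearClassifier {
    private final double[] coef;     // values of clf.coef_[0] exported from Python
    private final double intercept;  // value of clf.intercept_[0]

    LinearClassifier(double[] coef, double intercept) {
        this.coef = coef.clone();
        this.intercept = intercept;
    }

    // w . x + b: the same score that decision_function() returns for a binary model.
    double decisionFunction(double[] x) {
        if (x.length != coef.length) {
            throw new IllegalArgumentException("expected " + coef.length + " features");
        }
        double score = intercept;
        for (int i = 0; i < coef.length; i++) {
            score += coef[i] * x[i];
        }
        return score;
    }

    // Thresholded dot product: index 1 (i.e. clf.classes_[1]) if the score is positive, else 0.
    int predict(double[] x) {
        return decisionFunction(x) > 0 ? 1 : 0;
    }

    // Only meaningful for LogisticRegression: the logistic sigmoid of the score.
    double predictProbability(double[] x) {
        return 1.0 / (1.0 + Math.exp(-decisionFunction(x)));
    }
}
```

The returned index maps back to a label through `clf.classes_`, which you would export alongside the coefficients.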
Finally, you can use a Java library such as Weka or Mahout that implements the algorithms you need, instead of trying to use scikit-learn from Java.

- One of my coworkers just suggested Jepp... is that something that would work for this? – Thomas Johnson Oct 05 '12 at 13:44
- Probably; I did not know about Jepp. It does look suited to the task. – ogrisel Oct 05 '12 at 14:11
- For a web app, I personally like the HTTP exposure approach better. @user939259 could then use a classifier pool for various apps and scale it more easily (sizing the pool according to demand). I'd only consider Jepp for a desktop app. As much of a Python lover as I am, unless scikit-learn has significantly better performance than Weka or Mahout, I'd go for a single-language solution. Having more than one language/framework should be considered technical debt. – rbanffy Oct 06 '12 at 18:05
- I agree about the multi-language technical debt: it's hard to staff a team where all devs know both Java and Python, and having to switch from one technical culture to the other adds useless complexity to the management of the project. – ogrisel Oct 07 '12 at 15:40
- Maybe it is technical debt, but to stretch the metaphor, in machine learning you're constantly declaring bankruptcy anyway, because you're trying stuff out, finding it doesn't work, and tweaking it or throwing it away. So maybe the debt isn't as big a deal in a case like that. – Thomas Johnson Oct 07 '12 at 22:24
There is the JPMML project for this purpose.
First, you can serialize a scikit-learn model to PMML (which is XML internally) using the sklearn2pmml library directly from Python, or dump it in Python first and convert it with jpmml-sklearn, either from Java or from the command line tool provided by that library. Next, you can load the PMML file, deserialize it and execute the loaded model using jpmml-evaluator in your Java code.
This approach does not work with all scikit-learn models, but it works with many of them.
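Because a PMML file is plain XML, you can inspect which fields the exported model expects before wiring up jpmml-evaluator, which helps keep the Java-side feature extraction consistent with training. A minimal sketch using only the JDK's DOM parser: it lists the `DataField` entries of the PMML `DataDictionary` (name and `dataType`). It does not evaluate the model; use jpmml-evaluator for that. The class name is hypothetical.

```java
import java.io.InputStream;
import java.util.LinkedHashMap;
import java.util.Map;
import javax.xml.parsers.DocumentBuilderFactory;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.NodeList;

// Hypothetical helper: lists the DataField declarations of a PMML document.
final class PmmlFields {
    // Returns field name -> dataType, in document order (includes the target field).
    static Map<String, String> dataFields(InputStream pmml) throws Exception {
        DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance();
        factory.setNamespaceAware(true); // PMML elements live in a versioned dmg.org namespace
        // Refuse DOCTYPE declarations (hardening against XML external entities).
        factory.setFeature("http://apache.org/xml/features/disallow-doctype-decl", true);
        Document doc = factory.newDocumentBuilder().parse(pmml);
        // "*" matches any namespace, so this works for PMML 4.x files alike.
        NodeList nodes = doc.getElementsByTagNameNS("*", "DataField");
        Map<String, String> fields = new LinkedHashMap<>();
        for (int i = 0; i < nodes.getLength(); i++) {
            Element field = (Element) nodes.item(i);
            fields.put(field.getAttribute("name"), field.getAttribute("dataType"));
        }
        return fields;
    }
}
```

For example, `PmmlFields.dataFields(new FileInputStream("model.pmml"))` (file name illustrative) returns a map you can print or compare against the feature names your Java code produces.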
As some commenters correctly pointed out, it's important to note that the JPMML project is licensed under the GNU AGPL. The AGPL is a strong copyleft license, which may limit your ability to use the project. For example, this matters if you develop a publicly accessible service and want to keep its sources closed.

- How do you ensure that the feature transformation part is consistent between the one done in Python for training and the one done in Java (using PMML) for serving? – Andrea Bergonzo Sep 29 '17 at 19:38
- I tried this, and it definitely works for converting sklearn transformers and XGBoost models to Java. However, we didn't choose it for our production environment because of the AGPL license. (There is also a commercial license, but negotiating a license did not fit our project timeline.) – leon Oct 16 '19 at 05:52
- I tried this, keeping all the feature extraction, cleaning and transformation logic in the Java program, and it works fine on the Java side (jpmml-evaluator). It is a good option for a containerized Spring Boot application, greatly reducing DevOps complexity, since the frequency and timeline of the Python training cannot be synchronized with the continuous integration of the Java program. – Indrajit Kanjilal Aug 06 '20 at 07:32
- @leon's comment is super important, especially for people who copy/paste solutions from SO answers as a significant part of their software development lifecycle. *If you use jpmml-evaluator in your product, your users could force you to disclose all the source code to your product.* This is the Big Bad Wolf that Microsoft was warning people about when they equated all Open Source Software with libraries licensed under the GPL (*not* LGPL) and similar licenses. Always read your licenses! – Christopher Schultz Aug 05 '22 at 18:59
You can use a porter: I have tested sklearn-porter (https://github.com/nok/sklearn-porter), and it works well for Java.
My code is the following:
import pandas as pd
from sklearn import tree
from sklearn_porter import Porter

# DataFrame.as_matrix() was removed in pandas 1.0; to_numpy() is the replacement
train_dataset = pd.read_csv('./result2.csv').to_numpy()
X_train = train_dataset[:90, :8]
Y_train = train_dataset[:90, 8:]
X_test = train_dataset[90:, :8]
Y_test = train_dataset[90:, 8:]
print(X_train.shape)
print(Y_train.shape)

clf = tree.DecisionTreeClassifier()
clf = clf.fit(X_train, Y_train)

# Transpile the fitted tree into Java source code
porter = Porter(clf, language='java')
output = porter.export(embed_data=True)
print(output)
In my case, I'm using a DecisionTreeClassifier, and the output of print(output) is the following Java code, printed as text in the console:
class DecisionTreeClassifier {

    private static int findMax(int[] nums) {
        int index = 0;
        for (int i = 0; i < nums.length; i++) {
            index = nums[i] > nums[index] ? i : index;
        }
        return index;
    }

    public static int predict(double[] features) {
        int[] classes = new int[2];
        if (features[5] <= 51.5) {
            if (features[6] <= 21.0) {
                // HUGE amount of ifs..........
            }
        }
        return findMax(classes);
    }

    public static void main(String[] args) {
        if (args.length == 8) {
            // Features:
            double[] features = new double[args.length];
            for (int i = 0, l = args.length; i < l; i++) {
                features[i] = Double.parseDouble(args[i]);
            }
            // Prediction:
            int prediction = DecisionTreeClassifier.predict(features);
            System.out.println(prediction);
        }
    }
}
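The body of the generated predict() is elided above. To show the general shape of such a transpiled tree and how to call it from your own Java code, here is a hand-written two-level tree in the same style. The two thresholds mirror the visible ones, but the leaf counts are invented for illustration; this is not real sklearn-porter output.

```java
// Hypothetical hand-written tree in the style of the generated class; leaf counts are made up.
final class TinyTree {
    private static int findMax(int[] nums) {
        int index = 0;
        for (int i = 0; i < nums.length; i++) {
            index = nums[i] > nums[index] ? i : index;
        }
        return index;
    }

    static int predict(double[] features) {
        int[] classes = new int[2]; // per-class sample counts at the leaf that is reached
        if (features[5] <= 51.5) {
            if (features[6] <= 21.0) {
                classes[0] = 30; classes[1] = 2;  // leaf: mostly class 0
            } else {
                classes[0] = 4; classes[1] = 25;  // leaf: mostly class 1
            }
        } else {
            classes[0] = 0; classes[1] = 40;      // leaf: only class 1
        }
        return findMax(classes); // majority class at the leaf
    }
}
```

You call the generated class the same way, e.g. `int label = DecisionTreeClassifier.predict(features);` with a `double[]` of 8 features, or from the shell via its `main` method with the 8 feature values as arguments.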

- Thanks for the info. Can you share your ideas on how to execute a scikit-learn model pickled using sklearn-porter and use it for prediction in Java? @gustavoresque – Sourav Saha Apr 21 '20 at 07:52
Here is some code for the JPMML solution:
--PYTHON PART--
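from decimal import Decimal

from sklearn.ensemble import GradientBoostingClassifier
from sklearn.preprocessing import LabelBinarizer
from sklearn2pmml import sklearn2pmml
from sklearn2pmml.pipeline import PMMLPipeline
from sklearn_pandas import DataFrameMapper

# Helper function to determine the string columns which have to be one-hot-encoded
# in order to apply an estimator. Object columns holding Decimal values are numeric.
def determine_categorical_columns(df):
    categorical_columns = []
    for col, dtype in df.dtypes.items():
        if dtype == 'object':
            val = df[col].iloc[0]
            if not isinstance(val, Decimal):
                categorical_columns.append(col)
    return categorical_columns

# df holds the feature columns (loaded elsewhere)
categorical_columns = determine_categorical_columns(df)
other_columns = list(set(df.columns).difference(categorical_columns))

# Construction of the transformers for our example:
# binarize the categorical columns, pass the others through unchanged
label_binarizers = [(d, LabelBinarizer()) for d in categorical_columns]
nones = [(d, None) for d in other_columns]
transformers = label_binarizers + nones
mapper = DataFrameMapper(transformers, df_out=True)
gbc = GradientBoostingClassifier()

# Construction of the pipeline
lm = PMMLPipeline([
    ("mapper", mapper),
    ("estimator", gbc)
])

# Fit (y holds the target labels, assumed loaded elsewhere) and export the file
# that the Java part loads; sklearn2pmml() needs a Java runtime on the PATH
lm.fit(df, y)
sklearn2pmml(lm, "ScikitLearnNew.pmml")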
--JAVA PART--
// Written against an older jpmml-evaluator release (FieldName, ModelEvaluatorFactory);
// later releases replaced these classes, so adjust to the version you use.
import java.io.FileInputStream;
import java.io.InputStream;
import java.util.HashMap;
import java.util.Map;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.InputField;
import org.jpmml.evaluator.ModelEvaluatorFactory;
import org.jpmml.evaluator.ProbabilityDistribution;
import org.jpmml.evaluator.mining.MiningModelEvaluator;
import org.jpmml.evaluator.mining.SegmentResult;

// Initialisation.
String pmmlFile = "ScikitLearnNew.pmml";
PMML pmml;
try (InputStream is = new FileInputStream(pmmlFile)) {
    pmml = org.jpmml.model.PMMLUtil.unmarshal(is);
}
ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
MiningModelEvaluator evaluator = (MiningModelEvaluator) modelEvaluatorFactory.newModelEvaluator(pmml);

// Determine which features are required as input (keyed by lower-case name)
Map<String, Field> inputFieldMap = new HashMap<>();
for (InputField inputField : evaluator.getInputFields()) {
    String fieldName = inputField.getName().getValue();
    inputFieldMap.put(fieldName.toLowerCase(), inputField.getField());
}

// Prediction
Map<String, String> argsMap = new HashMap<>();
// ... fill argsMap with input (lower-case feature name -> raw value)

// Here we keep only the features that are required by the model
Map<FieldName, String> args = new HashMap<>();
for (Map.Entry<String, String> entry : argsMap.entrySet()) {
    Field f = inputFieldMap.get(entry.getKey());
    if (f != null) {
        args.put(f.getName(), entry.getValue());
    }
}

// The model is applied to the input; a probability distribution is obtained
Map<FieldName, ?> res = evaluator.evaluate(args);
SegmentResult segmentResult = (SegmentResult) res;
Object targetValue = segmentResult.getTargetValue();
ProbabilityDistribution probabilityDistribution = (ProbabilityDistribution) targetValue;

I found myself in a similar situation. I'd recommend carving out a classifier microservice: a service that runs the classifier in Python and exposes it through a RESTful API, using JSON or XML as the data-interchange format. I think this is a cleaner approach.
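On the Java side, such a service can be called with the JDK's built-in `java.net.http.HttpClient` (Java 11+), with no extra dependencies. A minimal sketch, assuming a hypothetical endpoint such as `http://localhost:5000/predict` (e.g. served by Flask) that accepts `{"features": [...]}` and returns JSON; the URL and the request schema are placeholders, not something the answers define.

```java
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.time.Duration;
import java.util.Arrays;
import java.util.stream.Collectors;

// Client for a hypothetical Python prediction service; the endpoint and the
// {"features": [...]} request schema are assumptions, not a fixed API.
final class PredictClient {
    private final URI endpoint;
    private final HttpClient client = HttpClient.newBuilder()
            .connectTimeout(Duration.ofSeconds(5))
            .build();

    PredictClient(String url) {
        this.endpoint = URI.create(url);
    }

    // Serialises one feature vector as JSON; assumes finite values (NaN is not valid JSON).
    static String toJson(double[] features) {
        return Arrays.stream(features)
                .mapToObj(d -> Double.toString(d))
                .collect(Collectors.joining(",", "{\"features\": [", "]}"));
    }

    HttpRequest buildRequest(double[] features) {
        return HttpRequest.newBuilder(endpoint)
                .timeout(Duration.ofSeconds(10))
                .header("Content-Type", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(toJson(features)))
                .build();
    }

    // Returns the raw JSON response body; parse it with the JSON library of your choice.
    String predict(double[] features) throws IOException, InterruptedException {
        HttpResponse<String> response =
                client.send(buildRequest(features), HttpResponse.BodyHandlers.ofString());
        if (response.statusCode() != 200) {
            throw new IOException("prediction service returned HTTP " + response.statusCode());
        }
        return response.body();
    }
}
```

The client object is reusable and thread-safe, so one instance can be shared by the whole application; scaling then happens on the Python side by adding service instances behind a load balancer.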

Alternatively, you can generate code directly from a trained model. m2cgen (https://github.com/BayesWitnesses/m2cgen) is a tool that can transpile a trained model into native code in several languages, including Java.
