UPDATE: I tried the following way to generate confidence scores, but it's giving me an exception. Here is the code snippet:
// Sigmoid over the raw margin (only meaningful for a binary model)
double margin = BLAS.dot(lrModel.weights(), dataVector);
double confScore = 1.0 / (1.0 + Math.exp(-margin));
And the exception I get:
Caused by: java.lang.IllegalArgumentException: requirement failed: BLAS.dot(x: Vector, y:Vector) was given Vectors with non-matching sizes: x.size = 198, y.size = 18
at scala.Predef$.require(Predef.scala:233)
at org.apache.spark.mllib.linalg.BLAS$.dot(BLAS.scala:99)
at org.apache.spark.mllib.linalg.BLAS.dot(BLAS.scala)
Can you please help? It seems the weights vector has more elements (198) than my feature vector (I generate 18 features), and BLAS.dot() requires both vectors to be the same length.
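From reading the 1.5.0 source of LogisticRegressionModel.predictPoint, my understanding (an assumption on my part, please correct me if wrong) is that for a multiclass model the weights vector is a flattened, row-major (numClasses - 1) x weightsPerClass matrix, with class 0 as the pivot class. That would explain the mismatch: the whole matrix cannot be dotted against a single 18-element feature vector. Based on that assumption, this is the per-class softmax computation I am now attempting:

import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.linalg.Vector;

public static double[] classProbabilities(final LogisticRegressionModel model,
        final Vector features) {
    final int numClasses = model.numClasses();
    final double[] weights = model.weights().toArray();
    // Assumed layout: row-major (numClasses - 1) x weightsPerClass; when the
    // model was trained with an intercept, each row has one extra trailing cell.
    final int weightsPerClass = weights.length / (numClasses - 1);
    final boolean hasIntercept = weightsPerClass == features.size() + 1;
    final double[] probs = new double[numClasses];
    double denom = 1.0; // pivot class 0 has margin 0, so exp(0) = 1
    for (int c = 0; c < numClasses - 1; c++) {
        double margin = 0.0;
        for (int j = 0; j < features.size(); j++) {
            margin += features.apply(j) * weights[c * weightsPerClass + j];
        }
        if (hasIntercept) {
            margin += weights[c * weightsPerClass + features.size()];
        }
        probs[c + 1] = Math.exp(margin);
        denom += probs[c + 1];
    }
    probs[0] = 1.0;
    for (int c = 0; c < numClasses; c++) {
        probs[c] /= denom; // softmax normalization
    }
    return probs;
}

The confidence for the predicted label would then be probs[(int) model.predict(features)]. Is this layout assumption correct, or is there an official API I am missing?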
I am trying to implement a program in Java that trains on an existing dataset and predicts on a new dataset, using the multiclass logistic regression implementation available in Spark MLlib (1.5.0). My train and predict programs are below. The problem is that model.predict(vector) (see the lrmodel.predict() call in the predict program) returns only the predicted label. But what if I need a confidence score? How do I get one? I have gone through the API and was unable to locate any method that returns a confidence score. Can anyone please help me out?
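The only related method I could find is clearThreshold(), but as far as I can tell it only helps in the binary case:

// My understanding of the 1.5.0 API: with the threshold cleared, predict()
// returns the raw class-1 probability instead of a 0/1 label (binary only).
lrmodel.clearThreshold();
double probability = lrmodel.predict(vector); // value in [0, 1]
// With setNumClasses(18), predict() still returns the winning class index,
// so this does not give me a multiclass confidence score.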
Train Program (generates a .model file)
public static void main(final String[] args) throws Exception {
JavaSparkContext jsc = null;
int salesIndex = 1;
try {
...
SparkConf sparkConf =
new SparkConf().setAppName("Hackathon Train").setMaster(
sparkMaster);
jsc = new JavaSparkContext(sparkConf);
...
JavaRDD<String> trainRDD = jsc.textFile(basePath + "old-leads.csv").cache();
final String firstRdd = trainRDD.first().trim();
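// Drop the CSV header by filtering out every line equal to the first line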
JavaRDD<String> tempRddFilter =
trainRDD.filter(new org.apache.spark.api.java.function.Function<String, Boolean>() {
private static final long serialVersionUID =
11111111111111111L;
public Boolean call(final String arg0) {
return !arg0.trim().equalsIgnoreCase(firstRdd);
}
});
...
JavaRDD<String> featureRDD =
tempRddFilter
.map(new org.apache.spark.api.java.function.Function<String, String>() {
private static final long serialVersionUID =
6948900080648474074L;
public String call(final String arg0)
throws Exception {
...
StringBuilder featureSet =
new StringBuilder();
...
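// Emit LIBSVM "index:value" pairs; loadLibSVMFile below expects
// 1-based, ascending indices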
featureSet.append(i - 2);
featureSet.append(COLON);
featureSet.append(strVal);
featureSet.append(SPACE);
}
return featureSet.toString().trim();
}
});
List<String> featureList = featureRDD.collect();
String featureOutput = StringUtils.join(featureList, NEW_LINE);
String filePath = basePath + "lr.arff";
FileUtils.writeStringToFile(new File(filePath), featureOutput,
"UTF-8");
JavaRDD<LabeledPoint> trainingData =
MLUtils.loadLibSVMFile(jsc.sc(), filePath).toJavaRDD().cache();
final LogisticRegressionModel model =
new LogisticRegressionWithLBFGS().setNumClasses(18).run(
trainingData.rdd());
ByteArrayOutputStream baos = new ByteArrayOutputStream();
ObjectOutputStream oos = new ObjectOutputStream(baos);
oos.writeObject(model);
oos.flush();
oos.close();
FileUtils.writeByteArrayToFile(new File(basePath + "lr.model"),
baos.toByteArray());
baos.close();
} catch (Exception e) {
e.printStackTrace();
} finally {
if (jsc != null) {
jsc.close();
}
}
}
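As an aside, I believe the 1.5.0 model implements Saveable, so instead of Java serialization the native save/load API should also work (a sketch; the directory name under basePath is just a placeholder):

// In the train program: writes metadata + Parquet under the given directory
model.save(jsc.sc(), basePath + "lrModelDir");

// In the predict program:
LogisticRegressionModel lrmodel =
    LogisticRegressionModel.load(jsc.sc(), basePath + "lrModelDir");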
Predict program (using the lr.model generated from the train program)
public static void main(final String[] args) throws Exception {
JavaSparkContext jsc = null;
int salesIndex = 1;
try {
...
SparkConf sparkConf =
new SparkConf().setAppName("Hackathon Predict").setMaster(sparkMaster);
jsc = new JavaSparkContext(sparkConf);
ObjectInputStream objectInputStream =
new ObjectInputStream(new FileInputStream(basePath
+ "lr.model"));
LogisticRegressionModel lrmodel =
(LogisticRegressionModel) objectInputStream.readObject();
objectInputStream.close();
...
JavaRDD<String> predictRDD = jsc.textFile(basePath + "new-leads.csv").cache();
final String firstRdd = predictRDD.first().trim();
JavaRDD<String> tempRddFilter =
predictRDD.filter(new org.apache.spark.api.java.function.Function<String, Boolean>() {
private static final long serialVersionUID =
11111111111111111L;
public Boolean call(final String arg0) {
return !arg0.trim().equalsIgnoreCase(firstRdd);
}
});
...
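// Broadcast the deserialized model once so each executor reuses a single copy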
final Broadcast<LogisticRegressionModel> broadcastModel =
jsc.broadcast(lrmodel);
JavaRDD<String> featureRDD =
tempRddFilter
.map(new org.apache.spark.api.java.function.Function<String, String>() {
private static final long serialVersionUID =
6948900080648474074L;
public String call(final String arg0)
throws Exception {
...
LogisticRegressionModel lrModel =
broadcastModel.value();
String row = arg0;
String[] featureSetArray =
row.split(CSV_SPLITTER);
...
final Vector vector =
Vectors.dense(doubleArr);
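// NOTE: for a multiclass model, predict() returns the winning
// class label here, not a probability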
double score = lrModel.predict(vector);
...
return csvString;
}
});
String outputContent =
featureRDD
.reduce(new org.apache.spark.api.java.function.Function2<String, String, String>() {
private static final long serialVersionUID =
1212970144641935082L;
public String call(final String arg0, final String arg1)
throws Exception {
...
}
});
...
FileUtils.writeStringToFile(new File(basePath
+ "predicted-sales-data.csv"), sb.toString(), "UTF-8");
} catch (Exception e) {
e.printStackTrace();
} finally {
if (jsc != null) {
jsc.close();
}
}
}
}
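For completeness, this is the quick diagnostic I now print right after deserializing the model, to chase the 198-vs-18 mismatch from the UPDATE above:

System.out.println("numClasses   = " + lrmodel.numClasses());
System.out.println("numFeatures  = " + lrmodel.numFeatures());
System.out.println("weights.size = " + lrmodel.weights().size());
// If weights.size() is (numClasses - 1) * weightsPerClass rather than
// numFeatures, a single BLAS.dot() against one feature vector cannot work.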