
I am experimenting with the spark.ml library and its pipelines capability. There seems to be a limitation in using SQL with splits (e.g. for train and test sets):

  • It is nice that spark.ml works off of a SchemaRDD, but there is no easy way to randomly split a SchemaRDD into test and train sets. I can use randomSplit(0.6, 0.4), but that gives back an array of RDDs that lose the schema. I can force a case class on them and convert them back to SchemaRDDs, but I have a lot of features. As a workaround I used filter with a basic partitioning condition based on one of my i.i.d. features. Any suggestions of what else can be done?
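
The workaround above becomes unnecessary on newer Spark releases, where randomSplit is available directly on DataFrame (the successor of SchemaRDD) and preserves the schema. A minimal sketch, assuming a newer Spark version and a DataFrame df holding the labeled data:

```java
// Assumes a newer Spark release with the DataFrame API, where
// DataFrame.randomSplit returns DataFrame[] with the schema intact.
DataFrame[] splits = df.randomSplit(new double[] { 0.6, 0.4 }, 12345L); // seed for reproducibility
DataFrame train = splits[0];
DataFrame test = splits[1];
// Both halves keep df.schema(), so they can be fed straight into a Pipeline.
```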

Regarding the generated model:

  • How do I access the model weights? The lr optimizer and lr model internally have weights, but it is unclear how to use them.
  • Seems like a reasonable pair of questions: I work on mllib and do understand them (though do not have an answer yet). Given the close votes I am going to edit the question to see if others can then agree. – WestCoastProjects Feb 06 '15 at 21:28

1 Answer


OK, for the second part of the question:

How do I access the model weights? The lr optimizer and lr model internally have weights, but it is unclear how to use them.

After going through the source of the library (with non-existent Scala knowledge), I found the following:

The LogisticRegressionModel (of spark.ml) has an attribute weights (of type Vector).

Case 1

If you have the LogisticRegressionModel (of spark.ml):

LogisticRegression lr = new LogisticRegression();
LogisticRegressionModel lr1 = lr.fit(df_train);
System.out.println("The weights are " + lr1.weights());

Case 2

If you have the PipelineModel, first get the LogisticRegressionModel (a Transformer) by using getModel:

    LogisticRegression lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01);
    Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] { lr });

    PipelineModel model = pipeline.fit(train_df);
    LogisticRegressionModel lrModel = model.getModel(lr);
    System.out.println("The weights are " + lrModel.weights());
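
On newer Spark releases, where PipelineModel.getModel was dropped in favor of a stages() accessor, the same retrieval can be sketched as follows (a hedged sketch, assuming the logistic regression is the last pipeline stage; note also that later releases rename weights to coefficients):

```java
// Assumes a newer Spark version where PipelineModel exposes stages().
Transformer[] stages = model.stages();
// The logistic regression was the only stage, so it is the last element.
LogisticRegressionModel lrModel2 =
    (LogisticRegressionModel) stages[stages.length - 1];
System.out.println("The weights are " + lrModel2.weights());
```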

If it is incorrect or there is a better way, do let me know.

Ankitp
  • Thanks. The question I had asked was on Spark version 1.2.0. As of the latest version (1.3.1) getting weights of a model is fairly easy as you described. – charmee Jun 01 '15 at 15:04