
I have managed to get the BERT model working with the johnsnowlabs spark-nlp library, and I am able to save the trained model to disk as follows.

Fit the model

df_bert_trained = bert_pipeline.fit(textRDD)

df_bert = df_bert_trained.transform(textRDD)

Save the model

df_bert_trained.write().overwrite().save("/home/XX/XX/trained_model")

However, two things remain unclear.

First, the docs at https://nlp.johnsnowlabs.com/docs/en/concepts state that one can load the model as

EmbeddingsHelper.load(path, spark, format, reference, dims, caseSensitive) 

but it is unclear to me what the "reference" argument represents here.

Second, has anyone managed to save the BERT embeddings as a pickle file in Python?

user8291021

1 Answer


In Spark NLP, BERT comes as a pre-trained model: it has already been trained, fitted, and saved in the right format.

That being said, there is no reason to fit or save it again. You can, however, save the result once you transform your DataFrame into a new DataFrame that carries the BERT embeddings for each token (see the sketches after the example below).

Example:

Start a Spark session in spark-shell with the Spark NLP package:

spark-shell --packages JohnSnowLabs:spark-nlp:2.4.0
import com.johnsnowlabs.nlp.annotator._
import com.johnsnowlabs.nlp.base._
import org.apache.spark.ml.Pipeline

// Define the pipeline stages
val documentAssembler = new DocumentAssembler()
  .setInputCol("text")
  .setOutputCol("document")

val sentence = new SentenceDetector()
  .setInputCols("document")
  .setOutputCol("sentence")

val tokenizer = new Tokenizer()
  .setInputCols(Array("sentence"))
  .setOutputCol("token")

// Download and load the pretrained BERT model
val embeddings = BertEmbeddings.pretrained(name = "bert_base_cased", lang = "en")
  .setInputCols("sentence", "token")
  .setOutputCol("embeddings")
  .setCaseSensitive(true)
  .setPoolingLayer(0)

val pipeline = new Pipeline()
  .setStages(Array(
    documentAssembler,
    sentence,
    tokenizer,
    embeddings
  ))

// Test and transform
val testData = Seq(
  "I like pancakes in the summer. I hate ice cream in winter.",
  "If I had asked people what they wanted, they would have said faster horses"
).toDF("text")

val predictionDF = pipeline.fit(testData).transform(testData)
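
To peek at the per-token vectors in predictionDF, you can explode the annotation array. This is a minimal sketch that assumes Spark NLP's standard Annotation schema, where the result field holds the token text and the embeddings field holds the vector:

// Sketch (schema assumption above): one row per token, with its vector.
predictionDF
  .selectExpr("explode(embeddings) as emb")
  .selectExpr("emb.result as token", "emb.embeddings as vector")
  .show(5, truncate = false)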

predictionDF is a DataFrame that contains BERT embeddings for each token in your dataset. The BertEmbeddings pre-trained models come from TF Hub, which means they are the exact same pre-trained weights published by Google. All five models are available:

  • bert_base_cased (en)
  • bert_base_uncased (en)
  • bert_large_cased (en)
  • bert_large_uncased (en)
  • bert_multi_cased (xx)
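
On your second question: instead of pickling from Python, one simple option is to persist the transformed DataFrame with plain Spark I/O, which keeps the nested annotation columns. This is a sketch using standard Spark, with a placeholder path:

// Persist the embeddings DataFrame as Parquet (placeholder path);
// Parquet preserves the nested annotation structure.
predictionDF.write.mode("overwrite").parquet("/home/XX/XX/bert_embeddings")

// Reload it in a later session without re-running the pipeline:
val restoredDF = spark.read.parquet("/home/XX/XX/bert_embeddings")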

Let me know if you have any questions or problems and I'll update my answer.


Maziyar