
I trained an encoder-decoder language model to generate text. I want to restrict the tokens the model can generate to a specific vocabulary. How can I do that?

I found that the generate function (model.generate) accepts a parameter called force_words_ids, which forces the model to generate *all* of the tokens in that list. I am looking for something similar, except that the model should only generate tokens from the list, without being required to use every one of them.

Minions

1 Answer


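As of transformers version 4.26.0, you can write a custom LogitsProcessor that sets the score of every token outside your allowed vocabulary to -inf, and pass it (wrapped in a LogitsProcessorList) as the logits_processor keyword argument to model.generate. The same logic should also work on logits during training, as long as the __call__ method of your LogitsProcessor is written with ordinary PyTorch tensor operations.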
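A minimal sketch of such a processor follows, assuming torch and transformers are installed. The class name `AllowedTokensLogitsProcessor` and its `allowed_token_ids` argument are hypothetical and not part of the library. Only `LogitsProcessor` and `LogitsProcessorList` come from transformers. The toy check at the end runs the processor on dummy scores without loading a model.

```python
import torch
from transformers import LogitsProcessor, LogitsProcessorList


class AllowedTokensLogitsProcessor(LogitsProcessor):
    """Hypothetical processor that only lets the decoder pick tokens from a fixed set."""

    def __init__(self, allowed_token_ids):
        # Deduplicate and store the allowed vocabulary as a 1-D index tensor
        self.allowed_token_ids = torch.tensor(sorted(set(allowed_token_ids)), dtype=torch.long)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # scores has shape (batch_size * num_beams, vocab_size)
        blocked = torch.ones_like(scores, dtype=torch.bool)
        # Unblock only the allowed token ids (moved to the same device as scores)
        blocked[:, self.allowed_token_ids.to(scores.device)] = False
        # Tokens outside the allowed set get -inf, so they can never be selected
        return scores.masked_fill(blocked, float("-inf"))


# Toy check: a 6-token vocabulary where only ids 1 and 4 are allowed
processors = LogitsProcessorList([AllowedTokensLogitsProcessor([1, 4])])
toy_scores = processors(torch.zeros((2, 1), dtype=torch.long), torch.zeros((2, 6)))
print(toy_scores)
```

With a real model you would build the allowed ids with your tokenizer. Include `tokenizer.eos_token_id` so generation can still stop. Then call something like `model.generate(**inputs, logits_processor=processors)`. Masking with -inf instead of deleting vocabulary entries keeps the model's output layer unchanged, so the same trained checkpoint can be used with or without the restriction.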

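It is also recommended to pass renormalize_logits=True if you generate longer output sequences (e.g. with beam search) and your processor makes significant changes to the logits vector. Search algorithms assume the scores are normalized, and masking out most of the vocabulary breaks that normalization.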
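To see why renormalization matters, here is a small PyTorch illustration on a toy 5-token vocabulary. It assumes that `renormalize_logits=True` amounts to applying a final log-softmax over the processed scores. After masking, the remaining log-probabilities no longer sum to 1, and another log-softmax restores a proper distribution over the allowed tokens.

```python
import torch

# Toy decoder logits for a 5-token vocabulary (batch size 1)
logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, 3.0]])
log_probs = logits.log_softmax(dim=-1)

# Only tokens 1 and 2 are allowed; block the rest, as a custom processor would
blocked = torch.tensor([[True, False, False, True, True]])
masked = log_probs.masked_fill(blocked, float("-inf"))

# The surviving probability mass is now less than 1
print(masked.exp().sum().item())

# Renormalizing spreads the full mass over the allowed tokens (sum is approximately 1)
renormalized = masked.log_softmax(dim=-1)
print(renormalized.exp().sum().item())
```

Without this step, beam search would accumulate scores that underweight every hypothesis by the blocked mass, which can skew comparisons between beams of different lengths.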