
I have an autoregressive language model in PyTorch that generates text (a collection of sentences) from a single input:

output_text = ["sentence_1. sentence_2. sentence_3. sentence_4."]

Note that the output of the language model is in the form of logits (unnormalized scores over the vocabulary, which a softmax turns into probabilities), which can be converted to token IDs or strings.
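The conversion from logits to token IDs (typically an `argmax`) is exactly where the gradient path is lost, because it returns integer indices. A minimal sketch with arbitrary toy shapes (4 generated tokens, a 10-word vocabulary) standing in for the real model output:

```python
import torch

torch.manual_seed(0)

# Toy stand-in for the language model's output: one row of vocabulary scores per token
logits = torch.randn(4, 10, requires_grad=True)

# Differentiable view: a probability distribution over the vocabulary for each token
probs = torch.softmax(logits, dim=-1)

# Non-differentiable view: discrete token IDs (an integer tensor, detached from the graph)
token_ids = logits.argmax(dim=-1)

print(probs.requires_grad)      # True
print(token_ids.requires_grad)  # False
print(token_ids.dtype)          # torch.int64
```

Anything computed from `token_ids` (or from strings decoded from them) cannot send gradients back to `logits`; anything computed from `logits` or `probs` can.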

Some of these sentences need to go into another model to get a loss that should affect only those sentences:

loss1 = model2("sentence_2")
loss2 = model2("sentence_4")
loss_total = loss1 + loss2

What is the correct way to split the generated text from the first model into sentences without breaking differentiability? That is, I want the corresponding text (from above) to look like a PyTorch tensor of tensors, so that I can then feed some of them into the next model:

[["sentence_1."],
 ["sentence_2."],
 ["sentence_3."],
 ["sentence_4."]]
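One practical catch with a "tensor of tensors": sentences usually have different token counts, so their logits cannot be stacked into one regular tensor. A minimal sketch, assuming hypothetical sentence lengths and a toy vocabulary size, of the two usual options: keep a Python list of per-sentence tensors, or pad them to a common length with `torch.nn.utils.rnn.pad_sequence`:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Hypothetical per-sentence logits with different lengths (tokens x vocabulary size)
vocab_size = 10
lengths = [3, 5, 2, 4]
base = torch.randn(sum(lengths), vocab_size, requires_grad=True)

# Option 1: a plain Python list of differentiable slices (views into `base`)
sentences = list(torch.split(base, lengths, dim=0))

# Option 2: pad to the longest sentence; torch.stack(sentences) would raise an
# error here because the sentence lengths differ
padded = pad_sequence(sentences, batch_first=True, padding_value=0.0)

print(padded.shape)          # torch.Size([4, 5, 10])
print(padded.requires_grad)  # True
```

The list keeps exact lengths; the padded tensor is convenient for batching but needs a mask (or the `lengths` list) so the second model ignores the padding.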

For example, Python's str.split(".") method would let me take each individual sentence and feed it into the second model to get a loss, but it would most likely break differentiability.

Penguin
  • Does the model actually give out a string or some encoding of it (I am not very familiar with language models)? AFAIK you can't have a tensor of type string in PyTorch. – GoodDeeds May 24 '22 at 21:57
  • Good question, I clarified my question! TL;DR: the output is in the form of logits which can be converted to token IDs or strings – Penguin May 24 '22 at 22:26
  • I am confused. Are you trying to make "extracting a substring" differentiable? How can you differentiate a string anyway? – ihdv May 25 '22 at 02:41
  • @ihdv Not quite. The strings are just a way to view the output, which is a tensor of tensors (logits) – Penguin May 25 '22 at 04:25
  • @Penguin, what is the sequence encoding you have chosen? – Ivan May 25 '22 at 07:34

1 Answer


Okay, solved it. Posting the answer for completeness.

Since the output is in the form of logits, I can take the argmax to get the index of each token. This tells me where each period is, i.e., where each sentence ends. The argmax itself is not differentiable, but it is only used to find the boundaries; I then split the logits themselves by slicing, which keeps the gradients:

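import torch

sentences_list = []
# Stand-in for the output logits: one row of vocabulary scores per generated token
r = torch.rand(50, 100, requires_grad=True)
period_indices = [10, 30, 49]  # positions of the "." tokens, found via argmax
start = 0
for end in period_indices:
    # Slicing keeps the autograd graph; end + 1 keeps each period with its own sentence
    sentences_list.append(r[start:end + 1])
    start = end + 1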

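Now each element in sentences_list holds the logits of one sentence, which I can send to another model to get a loss. Because the slices are views into the original logits, the gradients from that loss flow back only to the tokens of that sentence, provided the second model takes the logits (or probabilities derived from them) rather than token IDs or strings.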
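A slightly more complete sketch of the same approach: it locates the periods from the argmax token IDs, splits the logits with `torch.tensor_split` (an alternative to hand-written slices), computes a loss on sentences 2 and 4 only, and checks where the gradients land. `PERIOD_ID`, the toy logits, and `model2` are hypothetical stand-ins; in particular, this `model2` consumes probability-weighted ("soft") token embeddings, which is one common way to feed logits into a second model differentiably, not necessarily how the real second model works.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

PERIOD_ID = 7                    # hypothetical vocabulary index of "."
vocab_size, seq_len, dim = 10, 16, 8

# Toy logits for 4 sentences: "." is the top token only at positions 3, 7, 11, 15
raw = torch.randn(seq_len, vocab_size)
raw[:, PERIOD_ID] = -10.0
raw[[3, 7, 11, 15], PERIOD_ID] = 10.0
logits = raw.requires_grad_()

# argmax only decides where to cut; it is not part of the loss
token_ids = logits.argmax(dim=-1)
periods = (token_ids == PERIOD_ID).nonzero().flatten().tolist()
splits = [p + 1 for p in periods if p + 1 < seq_len]  # cut right after each "."
sentences_list = torch.tensor_split(logits, splits, dim=0)

# Hypothetical stand-in for model2: mixes embedding vectors by token probability
embedding = nn.Embedding(vocab_size, dim)

def model2(sentence_logits):
    soft_tokens = torch.softmax(sentence_logits, dim=-1) @ embedding.weight
    return soft_tokens.pow(2).mean()

# Loss on sentences 2 and 4 only, as in the question
loss_total = model2(sentences_list[1]) + model2(sentences_list[3])
loss_total.backward()

# Total gradient magnitude reaching each sentence's logits
grad_per_sentence = [
    g.abs().sum().item() for g in torch.tensor_split(logits.grad, splits, dim=0)
]
print(grad_per_sentence)  # expected: 0.0 for sentences 1 and 3, > 0 for sentences 2 and 4
```

The sentences that do not enter the loss receive exactly zero gradient, which is the "loss should affect only those sentences" behavior the question asks for.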

Penguin