I have an autoregressive language model in PyTorch that generates text as a collection of sentences, given one input:
output_text = ["sentence_1. sentence_2. sentence_3. sentence_4."]
Note that the language model outputs logits (unnormalized scores over the vocabulary, one vector per generated position), which can be converted to token IDs or decoded to strings.
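Here is a minimal sketch of that setup (the shapes are illustrative, and the tokenizer.decode call is a hypothetical placeholder for whatever decoding is used):

```python
import torch

vocab_size = 50257
seq_len = 20

# Stand-in for the model's output: one logit vector per generated position.
# In the real model this tensor carries grad_fn back to the model's weights.
logits = torch.randn(seq_len, vocab_size, requires_grad=True)

token_ids = logits.argmax(dim=-1)      # integer IDs: argmax has no gradient
print(token_ids.requires_grad)         # False: the autograd graph ends here
# text = tokenizer.decode(token_ids)   # plain Python string, fully detached
```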
Some of these sentences need to go into another model to get a loss that should affect only those sentences:
loss1 = model2("sentence_2")
loss2 = model2("sentence_4")
loss_total = loss1+loss2
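To make the goal concrete, here is a sketch with a hypothetical stand-in for model2 (a plain nn.Linear over logit slices; my real model2 consumes text, which is exactly the problem). It shows that plain tensor slicing keeps gradients flowing, and only through the selected sentences:

```python
import torch
import torch.nn as nn

vocab_size = 100
logits = torch.randn(12, vocab_size, requires_grad=True)  # stand-in for model1's output
model2 = nn.Linear(vocab_size, 1)                         # hypothetical stand-in

sentence_2 = logits[3:6]    # slicing is differentiable
sentence_4 = logits[9:12]

loss1 = model2(sentence_2).mean()
loss2 = model2(sentence_4).mean()
loss_total = loss1 + loss2
loss_total.backward()

print(logits.grad[0:3].abs().sum())   # tensor(0.): sentence_1 gets no gradient
print(logits.grad[3:6].abs().sum())   # nonzero: sentence_2 does
```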
What is the correct way to break/split the generated text from the first model without breaking differentiability? That is, the text above should end up as something like a PyTorch tensor of per-sentence tensors, so that some of them can then be fed to the next model:
"[["sentence_1."]
["sentence_2."]
["sentence_3."]
["sentence_4."]]
For example, Python's split(".") method would let me take each individual sentence and feed it into the second model to get a loss, but it operates on plain strings, so it would break differentiability.
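One direction I have considered (sketched below with a hypothetical period_id and hand-forced sentence boundaries): compute the split points on the detached token IDs, but apply the split to the logits tensor itself via torch.tensor_split, since tensor indexing is differentiable:

```python
import torch

vocab_size = 100
period_id = 13                                   # hypothetical token ID for "."
logits = torch.randn(12, vocab_size, requires_grad=True)
with torch.no_grad():
    logits[[2, 5, 8, 11], period_id] += 100.0    # force "." at fixed positions

token_ids = logits.argmax(dim=-1)                # detached integer IDs
ends = (token_ids == period_id).nonzero(as_tuple=True)[0] + 1
sentences = torch.tensor_split(logits, ends.tolist()[:-1], dim=0)

# `sentences` is a tuple of differentiable logit slices, one per sentence,
# so losses on sentences[1] and sentences[3] would backpropagate into model1.
for s in sentences:
    print(s.shape, s.grad_fn is not None)        # torch.Size([3, 100]) True
```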