
I am using the T5 model found on Hugging Face for text summarization. How can I output the logits of the T5 model directly given a text input for generation purposes (not training)?

I want to generate the output token by token so that I can calculate the entropy of each output token. It does not seem like the .generate() method will work for this.

I effectively want to create my own generate function but I need to obtain the logits of the model to be able to do this.
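For concreteness, here is a sketch of the kind of loop I have in mind: greedy decoding with t5-small as a stand-in checkpoint and a placeholder input, computing the standard entropy -Σ p log p at each step (untested against any particular transformers version):

```python
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
model.eval()

enc = tokenizer("summarize: The quick brown fox jumps over the lazy dog.",
                return_tensors="pt")

# T5 uses the pad token as the decoder start token.
decoder_ids = torch.tensor([[model.config.decoder_start_token_id]])

entropies = []
with torch.no_grad():
    for _ in range(20):  # max new tokens
        out = model(input_ids=enc.input_ids,
                    attention_mask=enc.attention_mask,
                    decoder_input_ids=decoder_ids)
        next_logits = out.logits[:, -1, :]   # logits for the next token only
        probs = F.softmax(next_logits, dim=-1)
        entropy = -(probs * torch.log(probs + 1e-12)).sum(dim=-1)
        entropies.append(entropy.item())
        next_id = probs.argmax(dim=-1, keepdim=True)  # greedy choice
        decoder_ids = torch.cat([decoder_ids, next_id], dim=-1)
        if next_id.item() == model.config.eos_token_id:
            break

print(tokenizer.decode(decoder_ids[0], skip_special_tokens=True))
print(entropies)
```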

muhleeshe

1 Answer


You can use the forward function to get your logits, then apply argmax as follows:

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch.nn.functional as F

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

input_ids = tokenizer("test here", padding="longest",
    max_length=128,
    truncation=True,
    return_tensors="pt"
)

logits = model(**input_ids).logits

preds = F.softmax(logits, dim=-1).argmax(dim=-1)
y = tokenizer.batch_decode(sequences=preds, skip_special_tokens=True)

You may check the original source here: Forward outputs on multiple sequences is wrong

Edwin Cheong
  • I have been trying to do something similar to this. The problem is that when I run the above code, or anything similar, I get this error: "ValueError: You have to specify either decoder_input_ids or decoder_inputs_embeds", which comes from calling model(input_ids). – muhleeshe Aug 11 '22 at 21:32
  • No, call ```model(**input_ids)```: input_ids is a dictionary, and ** unpacks it to match the parameters of the forward pass – Edwin Cheong Aug 12 '22 at 01:01
  • Ah I see. I took a look at the other post you mentioned and tried to get the forward pass function working. I was able to do so only when I used the generate() method to obtain labels to use as part of the input to the forward pass function. I was wondering if there is a way to do a forward pass on T5 without any label generation. – muhleeshe Aug 16 '22 at 23:15
  • @muhleeshe How did you solve the issue? – Saeed Rahmani Jul 30 '23 at 16:55
  • @SaeedRahmani I went through the source for hugging face's T5 looking for a way to do this, I was not able to. I assume you want to use this for research? In any case, I would suggest trying to work with a different library if your research does not specifically require T5. I did not even consider this before but alternatively, you could fork the repository and make some changes that allow you to have direct access to the logits at each distinct token in sequence rather than at the end of the process (this should be fairly doable imo). – muhleeshe Jul 31 '23 at 19:18
  • @muhleeshe Thanks for the reply! Yes, I am using it for research and I have to use T5 on Hugging Face. I do not know whether the forward pass runs auto-regressively or not. Any ideas? What I need is the conditional logits (probabilities) when we have a label but generate tokens auto-regressively. – Saeed Rahmani Aug 01 '23 at 07:14
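  • For what it's worth, recent transformers versions let .generate() return per-step scores directly, which may be enough to compute per-token entropy without a custom decoding loop. A sketch (t5-small and the input text are placeholders; with plain greedy search and no logits processors, the returned scores are essentially the raw logits at each step):

```python
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

enc = tokenizer("summarize: The quick brown fox jumps over the lazy dog.",
                return_tensors="pt")

out = model.generate(**enc, max_new_tokens=20,
                     return_dict_in_generate=True, output_scores=True)

# out.scores is a tuple with one (batch, vocab) score tensor per generated token.
entropies = [-(F.softmax(s, dim=-1) * F.log_softmax(s, dim=-1)).sum(-1).item()
             for s in out.scores]

print(tokenizer.decode(out.sequences[0], skip_special_tokens=True))
print(entropies)
```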