For my use case, I need to use model.forward() instead of model.generate(), i.e., instead of the code below:
outs = model.model.generate(
    input_ids=batch["source_ids"],
    attention_mask=batch["source_mask"],
    output_scores=True,
    max_length=model.model_arguments.max_output_seq_length,
)
preds_cleaned = [
    model.tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    for ids in outs
]
I need to call the model directly (a forward pass):
model_outputs = model.model(
    input_ids=batch["source_ids"],
    attention_mask=batch["source_mask"],
    labels=lm_labels.to(device),  # lm_labels holds the target token ids
    decoder_attention_mask=batch["target_mask"],
)
logits = model_outputs.logits
softmax_logits = m(logits)  # m = torch.nn.Softmax(dim=-1)
# torch.max returns a (values, indices) namedtuple; the indices are the predicted token ids
max_logits = torch.max(softmax_logits, dim=2).indices
Decoding these argmax token ids gives unprocessed text with many issues, such as words repeating at the end. What do I need to do to get the same result as model.generate()?