1

Relevant Code :

from transformers import (
    AdamW,
    MT5ForConditionalGeneration,
    AutoTokenizer,
    get_linear_schedule_with_warmup
)
tokenizer = AutoTokenizer.from_pretrained('google/byt5-small', use_fast=True)
model=MT5ForConditionalGeneration.from_pretrained("working/result/", 
                                            return_dict=True)






 def generate(text):
   model.eval()
#    print(model)
#    input_ids = tokenizer.encode("WebNLG:{} </s>".format(text), 
#                                return_tensors="pt")  
   input_ids = tokenizer.batch_encode_plus(
                [text], max_length=512, pad_to_max_length=True, return_tensors="pt"
            ).to(device)
   source_ids = input_ids["input_ids"].squeeze()
   print(tokenizer.decode(source_ids))
   print(type(input_ids.input_ids))
   input_ids.input_ids.to(device)
   print(input)

   outputs = model.generate(input_ids.input_ids)
   print(outputs)
   print(outputs[0])
   return tokenizer.decode(outputs[0])

Calling above function

input_str = "Title: %s Category: %s" % ("10 Min Quick Food Recipe","Food")
input_str = "Title: %s Category: %s" % ("I am marathon runner and going to run 21km on 4th dec in Thane","Fitness")

print(input_str)
print(generate(input_str))

Output:

Title: I am marathon runner and going to run 21km on 4th dec in Thane Category: Fitness
Title: I am marathon runner and going to run 21km on 4th dec in Thane Category: Fitness</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>
<class 'torch.Tensor'>
<bound method Kernel.raw_input of <ipykernel.ipkernel.IPythonKernel object at 0x7ff645eed970>>
tensor([[    0,   259,   266,   259,  3659,   390,   259,   262, 48580,   288,
           259,   262, 38226,  5401,   259,     1]], device='cuda:0')
tensor([    0,   259,   266,   259,  3659,   390,   259,   262, 48580,   288,
          259,   262, 38226,  5401,   259,     1], device='cuda:0')
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In [30], line 5
      2 input_str = "Title: %s Category: %s" % ("I am marathon runner and going to run 21km on 4th dec in Thane","Fitness")
      4 print(input_str)
----> 5 print(generate(input_str))

Cell In [29], line 18, in generate(text)
     16 print(outputs)
     17 print(outputs[0])
---> 18 return tokenizer.decode(outputs[0])

File ~/T5/t5_venv/lib/python3.8/site-packages/transformers/tokenization_utils_base.py:3436, in PreTrainedTokenizerBase.decode(self, token_ids, skip_special_tokens, clean_up_tokenization_spaces, **kwargs)
   3433 # Convert inputs to python lists
   3434 token_ids = to_py_obj(token_ids)
-> 3436 return self._decode(
   3437     token_ids=token_ids,
   3438     skip_special_tokens=skip_special_tokens,
   3439     clean_up_tokenization_spaces=clean_up_tokenization_spaces,
   3440     **kwargs,
   3441 )

File ~/T5/t5_venv/lib/python3.8/site-packages/transformers/tokenization_utils.py:949, in PreTrainedTokenizer._decode(self, token_ids, skip_special_tokens, clean_up_tokenization_spaces, spaces_between_special_tokens, **kwargs)
    947         current_sub_text.append(token)
    948 if current_sub_text:
--> 949     sub_texts.append(self.convert_tokens_to_string(current_sub_text))
    951 if spaces_between_special_tokens:
    952     text = " ".join(sub_texts)

File ~/T5/t5_venv/lib/python3.8/site-packages/transformers/models/byt5/tokenization_byt5.py:243, in ByT5Tokenizer.convert_tokens_to_string(self, tokens)
    241         tok_string = token.encode("utf-8")
    242     else:
--> 243         tok_string = bytes([ord(token)])
    244     bstring += tok_string
    245 string = bstring.decode("utf-8", errors="ignore")

ValueError: bytes must be in range(0, 256)

I tried to change the max_length param to 256 but can't seems to get it work. Any leads highly appreciated. Thanks in Advance.

iamabhaykmr
  • 1,803
  • 3
  • 24
  • 49

1 Answers1

0

Got it. I was doing a silly mistake. I was trying different pre-trained tokenizer and T5 models.

During training I had used google/mt5-base but during inference I used google/byt5-small which created this issue. Changed back to google/mt5-base to fix the issue. Now inference working fine.

iamabhaykmr
  • 1,803
  • 3
  • 24
  • 49