I tried to use the RoBERTa model from PyTorch Hub like this:
import torch
roberta = torch.hub.load('pytorch/fairseq', 'roberta.large')
tokens = roberta.encode('Roberta is a heavily optimized version of BERT.', 'Roberta is not very optimized.')
roberta.predict('mnli', tokens).argmax() # 0: contradiction
but I get the following error from PyTorch:
  458
  459     @_copy_to_script_wrapper
  460     def __getitem__(self, key: str) -> Module:
❱ 461         return self._modules[key]
  462
  463     def __setitem__(self, key: str, module: Module) -> None:
  464         self.add_module(key, module)
KeyError: 'mnli'
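The failing frame is `torch.nn.ModuleDict.__getitem__`, which simply returns `self._modules[key]`. Below is a minimal reproduction that does not download the model. It assumes that `roberta.predict('mnli', ...)` looks up its classification head by name in a `ModuleDict`, and that the checkpoint I loaded has no head registered under `'mnli'`. `classification_heads` and `predict` here are hypothetical stand-ins, not fairseq's actual code.

```python
import torch.nn as nn

# Hypothetical stand-in for the model's registry of classification heads.
# Assumption: the loaded checkpoint registers no 'mnli' head, so it starts empty.
classification_heads = nn.ModuleDict()


def predict(head: str) -> nn.Module:
    # ModuleDict.__getitem__ returns self._modules[key], so a missing key
    # surfaces as a plain KeyError, the same as in the traceback.
    return classification_heads[head]


try:
    predict("mnli")
except KeyError as err:
    print("KeyError:", err)  # prints KeyError: 'mnli'

# Once a head is registered under that name, the lookup succeeds.
classification_heads["mnli"] = nn.Linear(1024, 3)
print(type(predict("mnli")).__name__)  # prints Linear
```

This shows the error comes from a missing head name rather than from the tokenization step.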
How can I solve this?