
I am trying to play around with the Fairseq machine translation model, loaded using:

en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de',
                       checkpoint_file='model1.pt:model2.pt:model3.pt:model4.pt',
                       tokenizer='moses', bpe='fastbpe')

When I call en2de.generate(...), what does this function return?

This function is defined in the hub_utils.py file of the fairseq library.

I tried debugging the code, but didn't get anywhere. I need a better understanding of its return types.

  • "This function is defined in the hub_utils.py file of fairseq model"... So post this code. If your question relies on SO users tracking down external code modules on other websites to try to address your question, you will not get a quality answer. – DerekG May 30 '23 at 13:53
  • Welcome to Stack Overflow! It'll be useful for you to add where the code comes from, especially when it's not such a well-known library to most developers. Also, what errors are you seeing? Why are you not able to print the `type(...)` to find the return type of the function? – alvas May 31 '23 at 02:31


Most probably the code snippet you're looking at came from https://github.com/facebookresearch/fairseq/blob/main/examples/wmt19/README.md

The example code there looks like this:

import torch

# English to German translation
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de',
    checkpoint_file='model1.pt:model2.pt:model3.pt:model4.pt',
    tokenizer='moses', bpe='fastbpe')

But you'll most probably run into some environment setup issues, because fairseq isn't easily usable "off-the-shelf". So you'll have to do something like this:

! pip install -U fastBPE sacremoses 
! pip install -U hydra-core omegaconf bitarray

! git clone https://github.com/pytorch/fairseq && cd fairseq && pip install --editable ./
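If torch.hub.load still fails after this, a quick sanity check is to confirm the packages above are actually importable in the running kernel. This is a minimal sketch; the import names `fastBPE`, `sacremoses`, `hydra`, `omegaconf`, `bitarray` and `fairseq` are my assumption of each package's top-level module name, and `missing_modules` is a hypothetical helper.

```python
import importlib.util
from typing import Iterable, List

# Assumed top-level import names for the packages installed above.
REQUIRED_MODULES = ["fastBPE", "sacremoses", "hydra", "omegaconf", "bitarray", "fairseq"]


def missing_modules(names: Iterable[str]) -> List[str]:
    """Return the names that cannot be found on the current Python path."""
    # find_spec returns None for a top-level module that isn't installed,
    # without actually importing it.
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    print("missing:", missing or "none")
```

If anything is listed as missing, re-run the matching pip command and restart the kernel before loading the model.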

After setting up the environment, try loading the model again (this time with only model1.pt instead of all four checkpoints):

import torch

# English to German translation
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de',
    checkpoint_file='model1.pt',
    tokenizer='moses', bpe='fastbpe')

type(en2de)

[out]:

fairseq.hub_utils.GeneratorHubInterface

If we dig into the code, the class is defined at https://github.com/facebookresearch/fairseq/blob/main/fairseq/hub_utils.py#L97

class GeneratorHubInterface(nn.Module):
    """
    PyTorch Hub interface for generating sequences from a pre-trained
    translation or language model.
    """

And if we look at the translate() function, defined at https://github.com/facebookresearch/fairseq/blob/main/fairseq/hub_utils.py#LL133C1-L145C76, we see it simply calls sample():

    def translate(
        self, sentences: List[str], beam: int = 5, verbose: bool = False, **kwargs
    ) -> List[str]:
        return self.sample(sentences, beam, verbose, **kwargs)

    def sample(
        self, sentences: List[str], beam: int = 1, verbose: bool = False, **kwargs
    ) -> List[str]:
        if isinstance(sentences, str):
            return self.sample([sentences], beam=beam, verbose=verbose, **kwargs)[0]
        tokenized_sentences = [self.encode(sentence) for sentence in sentences]
        batched_hypos = self.generate(tokenized_sentences, beam, verbose, **kwargs)
        return [self.decode(hypos[0]["tokens"]) for hypos in batched_hypos]

So .translate() returns a list of strings (the best translation for each input sentence), or a single string if you pass in a single string.
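To see why the return type depends on the input, here is a toy stand-in that copies only the dispatch logic of sample() shown above. `ToyHub` and its fake encode/generate/decode methods are hypothetical placeholders, not fairseq code; they just split and uppercase the input so the example runs without downloading a model.

```python
from typing import Dict, List, Union


class ToyHub:
    """Hypothetical stand-in mimicking GeneratorHubInterface.sample()'s dispatch."""

    def encode(self, sentence: str) -> List[str]:
        # Fake "tokenization": split on whitespace.
        return sentence.split()

    def generate(self, tokenized: List[List[str]], beam: int = 1) -> List[List[Dict]]:
        # Fake generation: one list of `beam` hypothesis dicts per input sentence.
        return [[{"tokens": [t.upper() for t in toks]} for _ in range(beam)] for toks in tokenized]

    def decode(self, tokens: List[str]) -> str:
        return " ".join(tokens)

    def sample(self, sentences: Union[str, List[str]], beam: int = 1) -> Union[str, List[str]]:
        # A bare string is wrapped in a list and the single result unwrapped,
        # mirroring the fairseq code above.
        if isinstance(sentences, str):
            return self.sample([sentences], beam=beam)[0]
        tokenized = [self.encode(s) for s in sentences]
        batched_hypos = self.generate(tokenized, beam)
        # Keep only the first (best) hypothesis for each sentence.
        return [self.decode(hypos[0]["tokens"]) for hypos in batched_hypos]

    def translate(self, sentences: Union[str, List[str]], beam: int = 5) -> Union[str, List[str]]:
        return self.sample(sentences, beam)


hub = ToyHub()
print(hub.translate("machine learning"))          # MACHINE LEARNING
print(hub.translate(["machine learning", "hi"]))  # ['MACHINE LEARNING', 'HI']
```

Note that the `-> List[str]` annotation on the real translate() doesn't cover the single-string case, so type(...) on the result is the more reliable check.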

Digging deeper into the rabbit hole, we find the .generate() function at https://github.com/facebookresearch/fairseq/blob/main/fairseq/hub_utils.py#L170, whose signature shows that it returns a List[List[Dict[str, torch.Tensor]]]:

    def generate(
        self,
        tokenized_sentences: List[torch.LongTensor],
        beam: int = 5,
        verbose: bool = False,
        skip_invalid_size_inputs=False,
        inference_step_args=None,
        prefix_allowed_tokens_fn=None,
        **kwargs
    ) -> List[List[Dict[str, torch.Tensor]]]:

And if you call .generate() directly on an encoded sentence:

tokenized_sentences = en2de.encode("Machine learning is great!")
en2de.generate([tokenized_sentences])

[out]:

[[{'tokens': tensor([21259,    99,  4125, 15336,    34,  5013, 19663,   111,     2]),
   'score': tensor(-0.2017),
   'attention': tensor([[0.2876, 0.0079, 0.0066, 0.0211, 0.0117, 0.0107, 0.0026, 0.0049, 0.0067],
           [0.1374, 0.0239, 0.0076, 0.0090, 0.0062, 0.0049, 0.0021, 0.0029, 0.0034],
           [0.0817, 0.0073, 0.0472, 0.3804, 0.0206, 0.0112, 0.0031, 0.0072, 0.0059],
           [0.0684, 0.0017, 0.0033, 0.0079, 0.1894, 0.1042, 0.0093, 0.0214, 0.0088],
           [0.0862, 0.0021, 0.0021, 0.0055, 0.0991, 0.2868, 0.0274, 0.0126, 0.0065],
           [0.0415, 0.0053, 0.0049, 0.0089, 0.0388, 0.0405, 0.0146, 0.1026, 0.0346],
           [0.2972, 0.9517, 0.9284, 0.5673, 0.6342, 0.5417, 0.9409, 0.8484, 0.9341]]),
   'alignment': tensor([]),
   'positional_scores': tensor([-0.5091, -0.0979, -0.0993, -0.0672, -0.1520, -0.5898, -0.0818, -0.1108,
           -0.1069])},
  {'tokens': tensor([21259,    99,  4125, 15336,    34, 19503,   111,     2]),
   'score': tensor(-0.3501),
   'attention': tensor([[0.2876, 0.0079, 0.0066, 0.0211, 0.0117, 0.0107, 0.0048, 0.0073],
           [0.1374, 0.0239, 0.0076, 0.0090, 0.0062, 0.0049, 0.0027, 0.0037],
           [0.0817, 0.0073, 0.0472, 0.3804, 0.0206, 0.0112, 0.0070, 0.0063],
           [0.0684, 0.0017, 0.0033, 0.0079, 0.1894, 0.1042, 0.0217, 0.0097],
           [0.0862, 0.0021, 0.0021, 0.0055, 0.0991, 0.2868, 0.0129, 0.0078],
           [0.0415, 0.0053, 0.0049, 0.0089, 0.0388, 0.0405, 0.1076, 0.0373],
           [0.2972, 0.9517, 0.9284, 0.5673, 0.6342, 0.5417, 0.8431, 0.9280]]),
   'alignment': tensor([]),
   'positional_scores': tensor([-0.5091, -0.0979, -0.0993, -0.0672, -0.1520, -1.6566, -0.1113, -0.1072])},
  {'tokens': tensor([ 5725,   372,  8984,  3845,    34,  5013, 19663,   111,     2]),
   'score': tensor(-0.4066),
   'attention': tensor([[0.2876, 0.0278, 0.0040, 0.0030, 0.0192, 0.0150, 0.0032, 0.0075, 0.0083],
           [0.1374, 0.0755, 0.0019, 0.0379, 0.0087, 0.0062, 0.0027, 0.0044, 0.0043],
           [0.0817, 0.0269, 0.4516, 0.0801, 0.0227, 0.0120, 0.0038, 0.0084, 0.0065],
           [0.0684, 0.0034, 0.0067, 0.0091, 0.1939, 0.1039, 0.0097, 0.0224, 0.0099],
           [0.0862, 0.0031, 0.0040, 0.0030, 0.1022, 0.2868, 0.0296, 0.0135, 0.0073],
           [0.0415, 0.0058, 0.0054, 0.0066, 0.0373, 0.0400, 0.0146, 0.1016, 0.0351],
           [0.2972, 0.8574, 0.5264, 0.8603, 0.6160, 0.5361, 0.9364, 0.8422, 0.9287]]),
   'alignment': tensor([]),
   'positional_scores': tensor([-2.0029, -0.3431, -0.1785, -0.0286, -0.1586, -0.6527, -0.0782, -0.1101,
           -0.1071])},
  {'tokens': tensor([21259,    99,  4125, 15336,    34,  8404,   111,     2]),
   'score': tensor(-0.5465),
   'attention': tensor([[0.2876, 0.0079, 0.0066, 0.0211, 0.0117, 0.0107, 0.0047, 0.0074],
           [0.1374, 0.0239, 0.0076, 0.0090, 0.0062, 0.0049, 0.0026, 0.0037],
           [0.0817, 0.0073, 0.0472, 0.3804, 0.0206, 0.0112, 0.0071, 0.0064],
           [0.0684, 0.0017, 0.0033, 0.0079, 0.1894, 0.1042, 0.0221, 0.0095],
           [0.0862, 0.0021, 0.0021, 0.0055, 0.0991, 0.2868, 0.0125, 0.0077],
           [0.0415, 0.0053, 0.0049, 0.0089, 0.0388, 0.0405, 0.1046, 0.0372],
           [0.2972, 0.9517, 0.9284, 0.5673, 0.6342, 0.5417, 0.8464, 0.9282]]),
   'alignment': tensor([]),
   'positional_scores': tensor([-0.5091, -0.0979, -0.0993, -0.0672, -0.1520, -3.2290, -0.1100, -0.1075])},
  {'tokens': tensor([ 9467,  5293,    34,  5013, 19663,   111,     2]),
   'score': tensor(-0.5483),
   'attention': tensor([[0.2876, 0.0109, 0.0157, 0.0154, 0.0032, 0.0069, 0.0093],
           [0.1374, 0.0110, 0.0081, 0.0065, 0.0027, 0.0039, 0.0047],
           [0.0817, 0.2288, 0.0219, 0.0131, 0.0034, 0.0076, 0.0069],
           [0.0684, 0.0045, 0.1818, 0.0989, 0.0097, 0.0224, 0.0100],
           [0.0862, 0.0037, 0.0979, 0.2854, 0.0276, 0.0135, 0.0076],
           [0.0415, 0.0074, 0.0343, 0.0404, 0.0146, 0.1005, 0.0363],
           [0.2972, 0.7337, 0.6403, 0.5402, 0.9388, 0.8452, 0.9253]]),
   'alignment': tensor([]),
   'positional_scores': tensor([-2.1557, -0.5372, -0.1502, -0.6968, -0.0819, -0.1092, -0.1072])}]]

.generate() returns a list of lists of dicts, where the keys are field names ('tokens', 'score', 'attention', 'alignment', 'positional_scores') and the values are tensors.
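A small helper can make the nesting explicit: the outer list is indexed by input sentence, the inner list by hypothesis, and each hypothesis is a dict. This is a sketch; `summarize_generate_output` is a hypothetical name, and it only relies on len() and float(), so it should work on the tensors fairseq returns as well as on plain Python lists.

```python
from typing import Any, Dict, List, Sequence


def summarize_generate_output(
    results: Sequence[Sequence[Dict[str, Any]]],
) -> List[List[Dict[str, Any]]]:
    """Return, for every sentence and hypothesis, its score, token count and keys."""
    summary = []
    for hypos in results:                              # one entry per input sentence
        per_sentence = []
        for hypo in hypos:                             # one entry per hypothesis
            per_sentence.append({
                "score": float(hypo["score"]),         # 0-d tensor -> Python float
                "num_tokens": len(hypo["tokens"]),     # counts every id, incl. the trailing one
                "keys": sorted(hypo.keys()),
            })
        summary.append(per_sentence)
    return summary
```

On the output above, summarize_generate_output(en2de.generate([tokenized_sentences])) would give one inner list of 5 entries, the first with a score of about -0.2017 and 9 tokens.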

The outermost list holds one result per input sentence. With a single input sentence, its result is results[0]:

tokenized_sentences = [en2de.encode("Machine learning is great!")]
results = en2de.generate(tokenized_sentences)

translation_sent1 = results[0]

len(translation_sent1)

[out]:

5

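You'll see that each sentence has 5 translation results, because the beam size defaults to 5 (beam: int = 5 in the signature above). Each dictionary in the inner list is one hypothesis from the beam search; in the output above they appear in decreasing order of score. Setting beam=2 gives two hypotheses per sentence: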

tokenized_sentences = [en2de.encode("Machine learning is great!")]
results = en2de.generate(tokenized_sentences, beam=2)

translation_sent1 = results[0]

len(translation_sent1)

[out]:

2
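With several hypotheses per sentence you need to pick one. If you'd rather not rely on the order of the inner list, here is a sketch that selects the highest-scoring hypothesis explicitly; `best_per_sentence` is a hypothetical helper, not part of fairseq. In the 5-beam output above the hypotheses already come in decreasing order of score (-0.2017, -0.3501, -0.4066, ...), so it should agree with taking the first element.

```python
from typing import Any, Dict, List, Sequence


def best_per_sentence(results: Sequence[Sequence[Dict[str, Any]]]) -> List[Dict[str, Any]]:
    """Pick the hypothesis with the highest (least negative) score for each sentence."""
    # Scores are (length-normalized) log-probabilities, so closer to 0 is better.
    return [max(hypos, key=lambda h: float(h["score"])) for hypos in results]
```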

And to get the best translation:

tokenized_sentences = [en2de.encode("Machine learning is great!")]
results = en2de.generate(tokenized_sentences, beam=2)

translation_sent1 = results[0] # 2 translations from 2 beams for the 1st sentence.

best_translation = translation_sent1[0] # Best 1 translation out of the 2 beams.
best_translation

[out]:

{'tokens': tensor([21259,    99,  4125, 15336,    34,  5013, 19663,   111,     2]),
 'score': tensor(-0.2017),
 'attention': tensor([[0.2876, 0.0079, 0.0066, 0.0211, 0.0117, 0.0107, 0.0026, 0.0049, 0.0067],
         [0.1374, 0.0239, 0.0076, 0.0090, 0.0062, 0.0049, 0.0021, 0.0029, 0.0034],
         [0.0817, 0.0073, 0.0472, 0.3804, 0.0206, 0.0112, 0.0031, 0.0072, 0.0059],
         [0.0684, 0.0017, 0.0033, 0.0079, 0.1894, 0.1042, 0.0093, 0.0214, 0.0088],
         [0.0862, 0.0021, 0.0021, 0.0055, 0.0991, 0.2868, 0.0274, 0.0126, 0.0065],
         [0.0415, 0.0053, 0.0049, 0.0089, 0.0388, 0.0405, 0.0146, 0.1026, 0.0346],
         [0.2972, 0.9517, 0.9284, 0.5673, 0.6342, 0.5417, 0.9409, 0.8484, 0.9341]]),
 'alignment': tensor([]),
 'positional_scores': tensor([-0.5091, -0.0979, -0.0993, -0.0672, -0.1520, -0.5898, -0.0818, -0.1108,
         -0.1069])}
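Looking at the dict above, 'score' appears to be the average of 'positional_scores' (the per-token log-probabilities): the nine values average to about -0.2016, matching -0.2017 up to rounding. That is consistent with fairseq's default length normalization with a length penalty of 1.0, but treat it as an observation from this output rather than a guarantee for other settings. A quick check with the numbers copied from the output (`length_normalized_score` is a hypothetical helper):

```python
from typing import Sequence

# positional_scores and score copied from the best_translation output above.
positional_scores = [-0.5091, -0.0979, -0.0993, -0.0672, -0.1520,
                     -0.5898, -0.0818, -0.1108, -0.1069]
reported_score = -0.2017


def length_normalized_score(token_logprobs: Sequence[float], lenpen: float = 1.0) -> float:
    """Sum of per-token log-probs divided by length ** lenpen (lenpen=1 gives the plain mean)."""
    return sum(token_logprobs) / len(token_logprobs) ** lenpen


estimate = length_normalized_score(positional_scores)
print(round(estimate, 4), reported_score)  # -0.2016 -0.2017
```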

And to get the string representation, we fetch the tokens and decode them:

en2de.decode(best_translation['tokens'])

[out]:

Maschinelles Lernen ist großartig!
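The same decode call works for every hypothesis, not just the best one. Here is a sketch that turns a whole .generate() result into an n-best list of (score, text) pairs per sentence; the `decode` argument is meant to be en2de.decode, and `nbest_translations` is a hypothetical helper name.

```python
from typing import Any, Callable, Dict, List, Sequence, Tuple


def nbest_translations(
    results: Sequence[Sequence[Dict[str, Any]]],
    decode: Callable[[Any], str],
) -> List[List[Tuple[float, str]]]:
    """Decode every hypothesis of every sentence into (score, text) pairs."""
    return [
        [(float(hypo["score"]), decode(hypo["tokens"])) for hypo in hypos]
        for hypos in results
    ]


# With fairseq and the model loaded as above:
# nbest = nbest_translations(en2de.generate([en2de.encode("Machine learning is great!")]), en2de.decode)
```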

Here's a notebook with working code for the above examples: https://www.kaggle.com/alvations/how-to-use-fairseq-wmt19-models
