
Unbabel's COMET is a scoring library for machine translation. By default, loading a model as per the README works:

from comet import download_model, load_from_checkpoint

model_path = download_model("Unbabel/wmt22-comet-da")
model = load_from_checkpoint(model_path)
data = [
    {
        "src": "Dem Feuer konnte Einhalt geboten werden",
        "mt": "The fire could be stopped",
        "ref": "They were able to control the fire."
    },
    {
        "src": "Schulen und Kindergärten wurden eröffnet.",
        "mt": "Schools and kindergartens were open",
        "ref": "Schools and kindergartens opened"
    }
]
model_output = model.predict(data, batch_size=8, gpus=1)
print(model_output)

The download_model function is a thin wrapper over huggingface_hub.snapshot_download:

from huggingface_hub import snapshot_download
...

def download_model(
    model: str, 
    saving_directory: Union[str, Path, None] = None
) -> str:
    model_path = snapshot_download(repo_id=model, cache_dir=saving_directory)
    checkpoint_path = os.path.join(*[model_path, "checkpoints", "model.ckpt"])
    return checkpoint_path

That looks straightforward: download_model returns the path to the model checkpoint.
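
Going just by that signature, you can already pin the download to a directory you control; a minimal sketch (the directory name below is made up):

from comet import download_model

# saving_directory is forwarded to snapshot_download as cache_dir, so the
# snapshot lands under a directory you choose instead of the default
# Hugging Face cache (typically ~/.cache/huggingface).
checkpoint_path = download_model(
    "Unbabel/wmt22-comet-da",
    saving_directory="comet_models",  # hypothetical directory
)
print(checkpoint_path)  # ends with checkpoints/model.ckpt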

Then, under the hood, the model is a subclass of PyTorch Lightning's LightningModule, https://github.com/Unbabel/COMET/blob/master/comet/models/base.py#L56

import pytorch_lightning as ptl
...

class CometModel(ptl.LightningModule, metaclass=abc.ABCMeta):
    """CometModel: Base class for all COMET models. ..."""
    def __init__(
        self,...
    ) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.encoder = str2encoder[self.hparams.encoder_model].from_pretrained(
            self.hparams.pretrained_model, load_pretrained_weights
        )
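
Since CometModel subclasses LightningModule, a loaded model exposes the usual Lightning machinery, including the hyperparameters stored by save_hyperparameters(). A quick sketch to confirm this:

import pytorch_lightning as ptl
from comet import download_model, load_from_checkpoint

model = load_from_checkpoint(download_model("Unbabel/wmt22-comet-da"))
print(isinstance(model, ptl.LightningModule))  # True
print(model.hparams.pretrained_model)          # name of the underlying encoder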

While comet.load_from_checkpoint looks like LightningModule.load_from_checkpoint, it is not; it is a wrapper function around the LightningModule method, https://github.com/Unbabel/COMET/blob/master/comet/models/__init__.py

def load_from_checkpoint(checkpoint_path: str) -> CometModel:
    """Loads models from a checkpoint path.
    Args:
        checkpoint_path (str): Path to a model checkpoint.
    Return:
        COMET model.
    """
    checkpoint_path = Path(checkpoint_path)

    if not checkpoint_path.is_file():
        raise Exception(f"Invalid checkpoint path: {checkpoint_path}")
    
    parent_folder = checkpoint_path.parents[1] # .parent.parent
    hparams_file = parent_folder / "hparams.yaml"

    if hparams_file.is_file():
        with open(hparams_file) as yaml_file:
            hparams = yaml.load(yaml_file.read(), Loader=yaml.FullLoader)
        model_class = str2model[hparams["class_identifier"]]
        model = model_class.load_from_checkpoint(
            checkpoint_path, load_pretrained_weights=False
        )
        return model
    else:
        raise Exception(f"hparams.yaml file is missing from {parent_folder}!")

While comet.download_model and comet.load_from_checkpoint work out of the box, the nested wrappers obscure exactly where the model is stored and which checkpoint gets loaded.

Q: Is there a way to download and load the COMET models without using the default download_model and load_from_checkpoint?

The motivation for avoiding them is to know exactly where the model is stored and loaded from, both to prevent malicious directory access and to confine COMET's file access to a specific directory. Hence the need to specify where the model is downloaded and to know which part of it gets loaded.

alvas

  • Hey, Liling! I'd add a tag `machine-translation-quality-evaluation`, but all my karma is on Linguistics SE, not SO. – Adam Bittlingmayer Mar 29 '23 at 20:03
  • Hmmm, but it's longer than 35 chars (max allowed for tags), maybe `machine-translation-evaluations` or `machine-translation-metrics`? – alvas Mar 30 '23 at 09:31
  • Good point. I'd probably go for `machine-translation-evaluation` and `translation-evaluation` or `translation-quality-evaluation`. What do you think? Unlike confidence scoring, most of these approaches don't have requirements for how the target was generated. – Adam Bittlingmayer Mar 30 '23 at 10:52

1 Answer


You can try the following:

Default model

import os

from huggingface_hub import snapshot_download

from comet.models.regression.regression_metric import RegressionMetric

model_path = snapshot_download(repo_id="Unbabel/wmt22-comet-da",
                               cache_dir=os.path.abspath(os.path.dirname('.')))
model_checkpoint_path = f"{model_path}/checkpoints/model.ckpt"

# Call PyTorch Lightning's load_from_checkpoint (from its ModelIO mixin)
# directly on the model class
default_model = RegressionMetric.load_from_checkpoint(model_checkpoint_path)

data = [
    {
        "src": "Dem Feuer konnte Einhalt geboten werden",
        "mt": "The fire could be stopped",
        "ref": "They were able to control the fire."
    },
    {
        "src": "Schulen und Kindergärten wurden eröffnet.",
        "mt": "Schools and kindergartens were open",
        "ref": "Schools and kindergartens opened"
    }
]

default_model.predict(data, batch_size=8, gpus=1)

[out]:

Prediction([('scores', [0.8385581374168396, 0.9717257618904114]), ('system_score', 0.9051419496536255)])
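
Note that os.path.abspath(os.path.dirname('.')) resolves to the current working directory, so the snapshot lands in a predictable location under it; passing any other directory as cache_dir confines the download the same way.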

Referenceless (old model)

import os

from huggingface_hub import snapshot_download

from comet.models.regression.referenceless import ReferencelessRegression

model_path = snapshot_download(repo_id="Unbabel/wmt20-comet-qe-da", 
                               cache_dir=os.path.abspath(os.path.dirname('.')))
model_checkpoint_path = f"{model_path}/checkpoints/model.ckpt"

referenceless_model = ReferencelessRegression.load_from_checkpoint(model_checkpoint_path)

data = [
    {
        "src": "Dem Feuer konnte Einhalt geboten werden",
        "mt": "The fire could be stopped"
    },
    {
        "src": "Schulen und Kindergärten wurden eröffnet.",
        "mt": "Schools and kindergartens were open"
    }
]

referenceless_model.predict(data, batch_size=8, gpus=1)

[out]:

Prediction([('scores', [0.5127583742141724, 0.6407940983772278]),
            ('system_score', 0.5767762362957001)])
alvas