Unbabel's COMET is a scoring library for machine translation evaluation. By default, loading the model as shown in the README works:
from comet import download_model, load_from_checkpoint
model_path = download_model("Unbabel/wmt22-comet-da")
model = load_from_checkpoint(model_path)
data = [
    {
        "src": "Dem Feuer konnte Einhalt geboten werden",
        "mt": "The fire could be stopped",
        "ref": "They were able to control the fire."
    },
    {
        "src": "Schulen und Kindergärten wurden eröffnet.",
        "mt": "Schools and kindergartens were open",
        "ref": "Schools and kindergartens opened"
    }
]
model_output = model.predict(data, batch_size=8, gpus=1)
print(model_output)
The download_model function is a wrapper over huggingface_hub.snapshot_download:
from huggingface_hub import snapshot_download
...
def download_model(
    model: str,
    saving_directory: Union[str, Path, None] = None,
) -> str:
    model_path = snapshot_download(repo_id=model, cache_dir=saving_directory)
    checkpoint_path = os.path.join(*[model_path, "checkpoints", "model.ckpt"])
    return checkpoint_path
That looks straightforward, since download_model returns the path to the model checkpoint.
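Since download_model is just snapshot_download plus a path join, the same checkpoint path can be produced directly with an explicit target directory. Below is a minimal sketch of that; the downloader parameter is a hypothetical addition (not part of COMET) so the path logic can be exercised without a network call:

```python
import os
from pathlib import Path
from typing import Callable, Optional, Union


def download_comet_checkpoint(
    repo_id: str,
    target_dir: Union[str, Path],
    downloader: Optional[Callable[..., str]] = None,
) -> str:
    """Download a COMET model repo into an explicit directory and return
    the checkpoint path, mirroring what comet.download_model does.

    The downloader parameter is a hypothetical injection point; it
    defaults to huggingface_hub.snapshot_download and exists only so the
    path logic can be tested without touching the network.
    """
    if downloader is None:
        from huggingface_hub import snapshot_download
        downloader = snapshot_download
    model_path = downloader(repo_id=repo_id, cache_dir=str(target_dir))
    # Same layout as in the COMET source: <snapshot>/checkpoints/model.ckpt
    return os.path.join(model_path, "checkpoints", "model.ckpt")


# Usage (requires network access and huggingface_hub installed):
# ckpt = download_comet_checkpoint("Unbabel/wmt22-comet-da", "/tmp/comet-models")
```

Passing an explicit cache_dir keeps all downloaded files under a directory you control instead of the default Hugging Face cache.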
Then, under the hood, the model is a wrapper around PyTorch Lightning's LightningModule class (https://github.com/Unbabel/COMET/blob/master/comet/models/base.py#L56):
import pytorch_lightning as ptl
...
class CometModel(ptl.LightningModule, metaclass=abc.ABCMeta):
    """CometModel: Base class for all COMET models. ..."""

    def __init__(
        self, ...
    ) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.encoder = str2encoder[self.hparams.encoder_model].from_pretrained(
            self.hparams.pretrained_model, load_pretrained_weights
        )
While comet.load_from_checkpoint looks like LightningModule.load_from_checkpoint, it's not the same function. It is a module-level wrapper that dispatches to the appropriate model class's LightningModule.load_from_checkpoint (https://github.com/Unbabel/COMET/blob/master/comet/models/__init__.py):
def load_from_checkpoint(checkpoint_path: str) -> CometModel:
    """Loads models from a checkpoint path.

    Args:
        checkpoint_path (str): Path to a model checkpoint.

    Return:
        COMET model.
    """
    checkpoint_path = Path(checkpoint_path)

    if not checkpoint_path.is_file():
        raise Exception(f"Invalid checkpoint path: {checkpoint_path}")

    parent_folder = checkpoint_path.parents[1]  # .parent.parent
    hparams_file = parent_folder / "hparams.yaml"

    if hparams_file.is_file():
        with open(hparams_file) as yaml_file:
            hparams = yaml.load(yaml_file.read(), Loader=yaml.FullLoader)
        model_class = str2model[hparams["class_identifier"]]
        model = model_class.load_from_checkpoint(
            checkpoint_path, load_pretrained_weights=False
        )
        return model
    else:
        raise Exception(f"hparams.yaml file is missing from {parent_folder}!")
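The directory layout this function expects follows from its path arithmetic alone. A minimal sketch of that resolution (the paths here are made up for illustration):

```python
from pathlib import Path

# Mirror the path resolution inside comet.load_from_checkpoint:
# given <model_dir>/checkpoints/model.ckpt, hparams.yaml is expected
# next to the checkpoints/ directory, i.e. two levels up from the file.
checkpoint_path = Path("/tmp/comet-models/wmt22-comet-da/checkpoints/model.ckpt")

parent_folder = checkpoint_path.parents[1]  # same as .parent.parent
hparams_file = parent_folder / "hparams.yaml"

print(parent_folder)  # /tmp/comet-models/wmt22-comet-da
print(hparams_file)   # /tmp/comet-models/wmt22-comet-da/hparams.yaml
```

So a checkpoint can only be loaded this way if it sits inside a checkpoints/ directory with hparams.yaml one level above it.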
While the comet.download_model and comet.load_from_checkpoint functions work out of the box, the nested wrappers obscure exactly where the model is stored and which model files are loaded.

Q: Is there a way to download and load the COMET models without using the default download_model and load_from_checkpoint?
The motivation for avoiding them is to understand exactly where the model is stored and loaded from, to prevent any malicious directory access, and to confine the COMET functions' access to a specific directory. Hence the need to know and specify where the model is downloaded, and to know where, and which parts of, the model get loaded.