I am trying to fine-tune mBART-50 (paper, pre-trained model on Hugging Face) for machine translation using the transformers Python library. To test the fine-tuning, I am simply trying to teach mBART-50 a new word that I made up.
I use the following code; over 95% of it comes from the Hugging Face documentation:
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
print('Model loading started')
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50")
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50", src_lang="fr_XX", tgt_lang="en_XX")
print('Model loading done')
src_text = " billozarion "
tgt_text = " plorization "
model_inputs = tokenizer(src_text, return_tensors="pt")
with tokenizer.as_target_tokenizer():
    labels = tokenizer(tgt_text, return_tensors="pt").input_ids
print('Fine-tuning started')
for i in range(1000):
    #pass
    model(**model_inputs, labels=labels) # forward pass
print('Fine-tuning ended')
# Testing whether the model learned the new word. Translate French to English
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
tokenizer.src_lang = "fr_XX"
article_fr = src_text
encoded_fr = tokenizer(article_fr, return_tensors="pt")
generated_tokens = model.generate(**encoded_fr, forced_bos_token_id=tokenizer.lang_code_to_id["en_XX"])
translation = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
print(translation)
However, the new word wasn't learned. The output is "billozarion" instead of "plorization". Why?
I'm strictly following the Hugging Face documentation, unless I missed something. The # forward pass comment does concern me, as a backward pass is needed to compute gradients and an optimizer step is needed to actually update the weights. Maybe this means that the documentation example is incomplete, but I can't test that hypothesis because I don't know exactly how to add the backward pass.
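My untested guess, adapted from generic PyTorch training loops, is something like the sketch below. The AdamW optimizer, the 1e-5 learning rate, and keeping the 1000 steps are just assumptions on my part, not something taken from the documentation:

import torch

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)  # optimizer choice is my assumption
model.train()
for i in range(1000):
    optimizer.zero_grad()                           # reset gradients from the previous step
    outputs = model(**model_inputs, labels=labels)  # forward pass returns the loss when labels are given
    outputs.loss.backward()                         # backward pass computes the gradients
    optimizer.step()                                # optimizer step updates the weights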
Environment used to run the code: Ubuntu 20.04.5 LTS with an NVIDIA A100 40 GB GPU (also tested with an NVIDIA T4 Tensor Core GPU), CUDA 12.0, and the following conda environment:
conda create --name mbart-python39 python=3.9
conda activate mbart-python39
pip install transformers==4.28.1
pip install chardet==5.1.0
pip install sentencepiece==0.1.99
pip install protobuf==3.20