
I am fine-tuning a HuggingFace transformer model (PyTorch version), using the HF Seq2SeqTrainingArguments & Seq2SeqTrainer, and I want to display in Tensorboard the train and validation losses (in the same chart).

As far as I understand, to plot the two losses together I need to use the SummaryWriter. The HF Callbacks documentation describes a TensorBoardCallback class that can receive a tb_writer argument:

https://huggingface.co/docs/transformers/v4.21.1/en/main_classes/callback#transformers.integrations.TensorBoardCallback

However, I cannot figure out the right way to use it, or whether it is even supposed to be used with the Trainer API.

My code looks something like this:

args = Seq2SeqTrainingArguments(
    output_dir=output_dir,
    evaluation_strategy='epoch',
    learning_rate= 1e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=num_train_epochs,
    predict_with_generate=True,
    logging_steps=logging_steps,
    report_to='tensorboard',
    push_to_hub=False,  
)

trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_train_data,
    eval_dataset=tokenized_val_data,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

I assume I should pass a TensorBoard callback to the trainer, e.g.,

callbacks = [TensorBoardCallback(tb_writer=tb_writer)]

but I cannot find a comprehensive example of how to use it or what to import.

I also found this feature request on GitHub,

https://github.com/huggingface/transformers/pull/4020

but no example of its use, so I am confused.

Any insight will be appreciated.

James Hirschorn
anna-kay

2 Answers


The only way I know of to plot two values on the same TensorBoard graph is to use two separate SummaryWriters with the same root directory. For example, the logging directories might be: log_dir/train and log_dir/eval.

This approach is used in this answer, but for TensorFlow instead of PyTorch.
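Outside the Trainer, the idea can be sketched with plain PyTorch: two SummaryWriters under one root directory, both logging under the same tag. This is a minimal sketch assuming torch and tensorboard are installed; the helper name write_combined_losses and the loss values are hypothetical, for illustration only.

```python
import os
import tempfile

from torch.utils.tensorboard import SummaryWriter


def write_combined_losses(root_dir, train_losses, eval_losses):
    """Log train and eval losses under one tag, using one writer per subdirectory.

    TensorBoard treats each subdirectory of root_dir as a separate run and
    draws runs that share a tag on the same chart.
    """
    writers = {
        "train": SummaryWriter(log_dir=os.path.join(root_dir, "train")),
        "eval": SummaryWriter(log_dir=os.path.join(root_dir, "eval")),
    }
    series = {"train": train_losses, "eval": eval_losses}
    for mode, writer in writers.items():
        for step, loss in series[mode]:
            # Same tag for both runs -> one "combined/loss" chart with two curves.
            writer.add_scalar("combined/loss", loss, step)
        writer.close()


if __name__ == "__main__":
    root = tempfile.mkdtemp()
    # Hypothetical (global_step, loss) pairs.
    write_combined_losses(
        root,
        train_losses=[(10, 2.1), (20, 1.7), (30, 1.4)],
        eval_losses=[(15, 1.9), (30, 1.6)],
    )
    print(sorted(os.listdir(root)))  # ['eval', 'train']
```

Running `tensorboard --logdir` on that root directory should then show a single combined/loss chart with one curve for the train run and one for the eval run.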

To do this with the Trainer API, a custom callback that manages two SummaryWriters is needed. Here is the code for my custom callback, CombinedTensorBoardCallback, which I made by modifying the code of TensorBoardCallback:

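import logging
import os

from transformers import TrainerCallback
from transformers.integrations import is_tensorboard_available

# on_log below uses `logger`, so it has to be defined in this module.
logger = logging.getLogger(__name__)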

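def custom_rewrite_logs(d, mode):
    # Keep only the loss that belongs to `mode` and rename it to "combined/loss",
    # so the train and eval writers log under the same tag (and hence the same chart).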
    new_d = {}
    eval_prefix = "eval_"
    eval_prefix_len = len(eval_prefix)
    test_prefix = "test_"
    test_prefix_len = len(test_prefix)
    for k, v in d.items():
        if mode == 'eval' and k.startswith(eval_prefix):
            if k[eval_prefix_len:] == 'loss':
                new_d["combined/" + k[eval_prefix_len:]] = v
        elif mode == 'test' and k.startswith(test_prefix):
            if k[test_prefix_len:] == 'loss':
                new_d["combined/" + k[test_prefix_len:]] = v
        elif mode == 'train':
            if k == 'loss':
                new_d["combined/" + k] = v
    return new_d


class CombinedTensorBoardCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [TensorBoard](https://www.tensorflow.org/tensorboard).
    Args:
        tb_writer (`SummaryWriter`, *optional*):
            The writer to use. Will instantiate one if not set.
    """

    def __init__(self, tb_writers=None):
        has_tensorboard = is_tensorboard_available()
        if not has_tensorboard:
            raise RuntimeError(
                "TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or"
                " install tensorboardX."
            )
        if has_tensorboard:
            try:
                from torch.utils.tensorboard import SummaryWriter  # noqa: F401

                self._SummaryWriter = SummaryWriter
            except ImportError:
                try:
                    from tensorboardX import SummaryWriter

                    self._SummaryWriter = SummaryWriter
                except ImportError:
                    self._SummaryWriter = None
        else:
            self._SummaryWriter = None
        self.tb_writers = tb_writers

    def _init_summary_writer(self, args, log_dir=None):
        log_dir = log_dir or args.logging_dir
        if self._SummaryWriter is not None:
            self.tb_writers = dict(train=self._SummaryWriter(log_dir=os.path.join(log_dir, 'train')),
                                   eval=self._SummaryWriter(log_dir=os.path.join(log_dir, 'eval')))

    def on_train_begin(self, args, state, control, **kwargs):
        if not state.is_world_process_zero:
            return

        log_dir = None

        if state.is_hyper_param_search:
            trial_name = state.trial_name
            if trial_name is not None:
                log_dir = os.path.join(args.logging_dir, trial_name)

        if self.tb_writers is None:
            self._init_summary_writer(args, log_dir)

        for k, tbw in self.tb_writers.items():
            tbw.add_text("args", args.to_json_string())
            if "model" in kwargs:
                model = kwargs["model"]
                if hasattr(model, "config") and model.config is not None:
                    model_config_json = model.config.to_json_string()
                    tbw.add_text("model_config", model_config_json)
            # Version of TensorBoard coming from tensorboardX does not have this method.
            if hasattr(tbw, "add_hparams"):
                tbw.add_hparams(args.to_sanitized_dict(), metric_dict={})

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not state.is_world_process_zero:
            return

        if self.tb_writers is None:
            self._init_summary_writer(args)

        for tbk, tbw in self.tb_writers.items():
            logs_new = custom_rewrite_logs(logs, mode=tbk)
            for k, v in logs_new.items():
                if isinstance(v, (int, float)):
                    tbw.add_scalar(k, v, state.global_step)
                else:
                    logger.warning(
                        "Trainer is attempting to log a value of "
                        f'"{v}" of type {type(v)} for key "{k}" as a scalar. '
                        "This invocation of Tensorboard's writer.add_scalar() "
                        "is incorrect so we dropped this attribute."
                    )
            tbw.flush()

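    def on_train_end(self, args, state, control, **kwargs):
        # Writers are only created on the main process, so they may still be None here.
        if self.tb_writers:
            for tbw in self.tb_writers.values():
                tbw.close()
            self.tb_writers = None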

If you want to combine train and eval values for metrics other than the loss, modify custom_rewrite_logs accordingly.
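As a sketch of such a modification, here is a hypothetical rewrite_logs_for_mode that takes the metric names to combine as a parameter instead of hard-coding the loss. It assumes the Trainer's usual key naming (training values unprefixed, evaluation values prefixed with eval_, predict values with test_); eval_rouge1 is only an illustrative metric name.

```python
def rewrite_logs_for_mode(logs, mode, metrics=("loss",)):
    """Return the entries of ``logs`` that belong to ``mode`` and are listed in
    ``metrics``, renamed to a shared "combined/<metric>" tag.

    Assumes training values have no prefix ("loss"), evaluation values start
    with "eval_" and predict values with "test_".
    """
    prefix = {"train": "", "eval": "eval_", "test": "test_"}[mode]
    new_logs = {}
    for key, value in logs.items():
        # Every writer sees every log dict, so the train writer must skip eval/test keys.
        if mode == "train" and key.startswith(("eval_", "test_")):
            continue
        if key.startswith(prefix) and key[len(prefix):] in metrics:
            new_logs["combined/" + key[len(prefix):]] = value
    return new_logs


# Hypothetical log dicts, shaped like the ones Trainer passes to on_log.
train_logs = {"loss": 1.25, "learning_rate": 1e-5, "epoch": 1.0}
eval_logs = {"eval_loss": 1.1, "eval_rouge1": 0.42, "epoch": 1.0}

print(rewrite_logs_for_mode(train_logs, "train"))  # {'combined/loss': 1.25}
print(rewrite_logs_for_mode(eval_logs, "eval", metrics=("loss", "rouge1")))
# {'combined/loss': 1.1, 'combined/rouge1': 0.42}
```

Since the callback calls custom_rewrite_logs(logs, mode=tbk), one option is to bind the metric list with functools.partial, e.g. `custom_rewrite_logs = functools.partial(rewrite_logs_for_mode, metrics=("loss", "rouge1"))`. A metric only gets a train curve if it is actually logged during training.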

As usual, the callback goes in the Trainer constructor. In my test example it was:

trainer = Trainer(
    model=rnn,
    args=train_args,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[CombinedTensorBoardCallback]
)

You may also want to remove the default TensorBoardCallback; otherwise, in addition to the combined loss graph, the training loss and validation loss will each still appear on their own separate charts, as they do by default.

from transformers.integrations import TensorBoardCallback
trainer.remove_callback(TensorBoardCallback)

Here is the resulting TensorBoard view:

[Screenshot of the resulting TensorBoard view]

James Hirschorn
  • Thanks for the answer. I have no trouble outputting events for TensorBoard; I want to output the train and validation losses on the ***same*** plot (by default TensorBoard produces, among others, two separate plots, one for the train loss and one for the validation loss). I assumed that to do this I would have to tweak the tb_writer and pass it as an argument to a TensorBoardCallback. Do you have any idea how to do this? – anna-kay Oct 29 '22 at 15:05
  • Oh, I did not read your question carefully. I have edited the title of this question to make it clear what you are asking about. I hope my new answer helps. By the way, I do not know whether it is even possible by tweaking a single `tb_writer`, so instead I used two. – James Hirschorn Nov 15 '22 at 22:44

It's pretty simple: you specify it in Seq2SeqTrainingArguments (via report_to). There is no need to define it explicitly in the Seq2SeqTrainer constructor.

model_arguments = Seq2SeqTrainingArguments(output_dir= "./best_model/",
                                        num_train_epochs = EPOCHS, 
                                        overwrite_output_dir= True, 
                                        do_train= True, 
                                        do_eval= True, 
                                        do_predict= True, 
                                        auto_find_batch_size= True, 
                                        evaluation_strategy = 'epoch',
                                        warmup_steps = 10000, 
                                        logging_dir = "./log_files/", 
                                        disable_tqdm = False, 
                                        load_best_model_at_end = True, 
                                        save_strategy= 'epoch', 
                                        save_total_limit = 1, 
                                        per_device_eval_batch_size= BATCH_SIZE, 
                                        per_device_train_batch_size= BATCH_SIZE, 
                                        predict_with_generate=True, 
                                        report_to='wandb',
                                        run_name="rober_based_encoder_decoder_text_summarisation"
                                        
                                        )

Meanwhile, you can add other callbacks, e.g. early stopping:

from transformers import EarlyStoppingCallback

early_stopping = EarlyStoppingCallback(early_stopping_patience=5,
                                       early_stopping_threshold=0.001)

Then pass the arguments, and the callbacks as a list, to the trainer:

trainer = Seq2SeqTrainer(model = model, 
                        compute_metrics= compute_metrics,
                        args= model_arguments, 
                        train_dataset= Train, 
                        eval_dataset= Val, 
                        tokenizer=tokenizer, 
                        callbacks= [early_stopping, ]
                        )

Train the model. Make sure you are logged in to wandb before training:

trainer.train()
Junaid