
I'm fine-tuning a pre-trained BERT model using PyTorch.

The problem is that the GPU results and the TPU results are slightly different (accuracy differs by about ±2%).

I used the same dataset and the same seed (see the seeding sketch after the settings below), and the settings differ as follows:

device: GPU (RTX 2080 Ti x 4) vs. TPU v2 (1 core)
pytorch version: torch 1.5 (GPU) vs. torch 1.10 & torch_xla 1.10 (TPU)
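
For context, "same seed" means I seed every RNG I can before training, roughly like the sketch below (this helper is not part of the training class shown later):

import random

import numpy as np
import torch

def set_seed(seed: int) -> None:
    # Seed every RNG that can influence training.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# Note: even with identical seeds, GPU (cuDNN) and TPU (XLA) execute different
# kernels, so bit-exact agreement between the two backends is not guaranteed.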

# In the TPU setup, I modified these lines:

# 1. set the device  
self.device = xm.xla_device()

# 2. optimizer
xm.optimizer_step(self.optimizer, barrier=True)

I set the PyTorch version for the TPU setup to 1.10 because the Google Cloud VM and TPU images do not offer torch 1.5.
The full code of the training class is as follows:

from logging import Logger

import torch
from sklearn.metrics import accuracy_score, classification_report
from torch import nn
from torch.optim.adamw import AdamW
from torch.utils.data.dataloader import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from transformers import get_linear_schedule_with_warmup

from tasks.cola.config import TrainConfig
from tasks.cola.model import COLAModel


class Trainer:
    def __init__(
        self,
        config: TrainConfig,
        model: COLAModel,
        train_data_loader: DataLoader,
        dev_data_loader: DataLoader,
        test_data_loader: DataLoader,
        logger: Logger,
        summary_writer: SummaryWriter,
    ):
        self.config = config

        if config.use_tpu:
            import torch_xla.core.xla_model as xm  # for using TPU
            self.device = xm.xla_device()
            self.model = model
            print('TPU running...')
        else:
            # multi GPU (4 devices)
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            if (self.device.type == 'cuda') and (torch.cuda.device_count() > 1):
                print('Multi GPU({}) activate'.format(torch.cuda.device_count()))
                self.model = nn.DataParallel(model, device_ids=[0,1,2,3])
            else:
                self.model = model

        self.model.to(self.device)
     
        self.train_data_loader = train_data_loader
        self.dev_data_loader = dev_data_loader
        self.test_data_loader = test_data_loader
        self.logger = logger
        self.summary_writer = summary_writer

        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = AdamW(model.parameters(), lr=config.learning_rate)

        self.steps_per_epoch = len(train_data_loader)
        self.total_steps = self.steps_per_epoch * config.num_epochs
        self.warmup_steps = int(config.warmup_step_ratio * self.total_steps)

        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer, num_warmup_steps=self.warmup_steps, num_training_steps=self.total_steps
        )
        self.global_step = 0

    def train(self):
        # train
        self.logger.info("========== train ==========")
        self.logger.info(f"device                : {self.device}")
        self.logger.info(f"dataset length/ train : {len(self.train_data_loader.dataset)}")
        self.logger.info(f"dataset length/ dev   : {len(self.dev_data_loader.dataset)}")
        self.logger.info(f"dataset length/ test  : {len(self.test_data_loader.dataset)}")
        self.logger.info(f"batch size            : {self.config.batch_size}")
        self.logger.info(f"learning rate         : {self.config.learning_rate}")
        self.logger.info(f"dropout prob          : {self.config.dropout_prob}")
        self.logger.info(f"total epoch           : {self.config.num_epochs}")
        self.logger.info(f"steps per epoch       : {self.steps_per_epoch}")
        self.logger.info(f"total steps           : {self.total_steps}")
        self.logger.info(f"warmup steps          : {self.warmup_steps}\n")

        for epoch in range(self.config.num_epochs):
            running_loss = 0.0
            train_targets = []
            train_predictions = []

            for step, data in enumerate(tqdm(self.train_data_loader)):
                self.model.train()

                self.global_step += 1

                input_token_ids = data[0].to(self.device)
                attention_mask = data[1].to(self.device)
                token_type_ids = data[2].to(self.device)
                labels = data[3].to(self.device)

                loss, outputs = self._train_step(input_token_ids, attention_mask, token_type_ids, labels)

                running_loss += loss
                train_targets.extend(labels.tolist())
                train_predictions.extend(outputs.argmax(-1).tolist())

                if (step + 1) % self.config.logging_interval == 0:
                    train_loss = running_loss / self.config.logging_interval
                    train_acc = accuracy_score(train_targets, train_predictions)
                    self.logger.info(f"Epoch {epoch}, Step {step + 1}\t| Loss {train_loss:.4f}  Acc {train_acc:.4f}")

                    self.summary_writer.add_scalar("cola/train/loss", train_loss, self.global_step)
                    self.summary_writer.add_scalar("cola/train/accuracy", train_acc, self.global_step)

                    running_loss = 0.0
                    train_targets = []
                    train_predictions = []

            # dev every epoch
            dev_loss, dev_targets, dev_predictions = self._validation(self.dev_data_loader)
            dev_report = classification_report(dev_targets, dev_predictions, digits=4)
            self.logger.info(f"######### DEV REPORT #EP{epoch} #########")
            self.logger.info(f"Loss {dev_loss:.4f}")
            self.logger.info(f"\n{dev_report}")

            dev_acc = accuracy_score(dev_targets, dev_predictions)
            self.summary_writer.add_scalar("cola/dev/loss", dev_loss, self.global_step)
            self.summary_writer.add_scalar("cola/dev/accuracy", dev_acc, self.global_step)

            # test every epoch
            test_loss, test_targets, test_predictions = self._validation(self.test_data_loader)
            test_report = classification_report(test_targets, test_predictions, digits=4)
            self.logger.info(f"######### TEST REPORT #EP{epoch} #########")
            self.logger.info(f"Loss {test_loss:.4f}")
            self.logger.info(f"\n{test_report}")

            test_acc = accuracy_score(test_targets, test_predictions)
            self.summary_writer.add_scalar("cola/test/loss", test_loss, self.global_step)
            self.summary_writer.add_scalar("cola/test/accuracy", test_acc, self.global_step)

            # output_path = os.path.join(self.config.checkpoint_dir, f"model-epoch-{epoch}.pth")
            # torch.save(self.model.state_dict(), output_path)
            # self.logger.info(f"MODEL IS SAVED AT {output_path}\n")

    def _train_step(self, input_token_ids, attention_mask, token_type_ids, labels):
        self.optimizer.zero_grad()

        outputs = self.model(input_token_ids, attention_mask, token_type_ids)

        loss = self.criterion(outputs, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        if self.config.use_tpu:
            # optimizer for TPU (Note: Cloud TPU-specific code!)
            import torch_xla.core.xla_model as xm # for using tpu
            xm.optimizer_step(self.optimizer, barrier=True) 
        else:
            self.optimizer.step()

        self.scheduler.step()

        return loss.item(), outputs

    def _validation(self, data_loader):
        self.model.eval()

        running_loss = 0.0
        targets = []
        predictions = []

        with torch.no_grad():
            for data in data_loader:
                input_token_ids = data[0].to(self.device)
                attention_mask = data[1].to(self.device)
                token_type_ids = data[2].to(self.device)
                labels = data[3].to(self.device)

                outputs = self.model(input_token_ids, attention_mask, token_type_ids)

                loss = self.criterion(outputs, labels)

                running_loss += loss.item()
                targets.extend(labels.tolist())
                predictions.extend(outputs.argmax(-1).tolist())

        assert len(targets) == len(predictions)

        mean_loss = running_loss / len(data_loader)

        return mean_loss, targets, predictions

Is there any way to fix this problem?

Kyle

1 Answer


Assuming you implemented the loss reduction correctly, a common source of error when comparing metrics is using the same batch size per worker on GPU and TPU while the number of workers differs. If you have a per-worker batch size of, say, 128, a 4-GPU system gives you an effective batch size of 512, while an 8-core TPU gives you an effective batch size of 1024.
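
In numbers (a quick sketch; the batch sizes and worker counts below are example values, not taken from the question):

per_worker_batch_size = 128

num_gpu_workers = 4   # e.g. one worker per GPU
num_tpu_workers = 8   # e.g. an 8-core TPU

effective_gpu_batch = per_worker_batch_size * num_gpu_workers  # 512
effective_tpu_batch = per_worker_batch_size * num_tpu_workers  # 1024

# To compare metrics fairly, choose per-worker batch sizes so the effective
# batch sizes match, for example:
tpu_per_worker_batch_size = effective_gpu_batch // num_tpu_workers  # 64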

Sascha Kirch