I'm trying to use the Sharpness-Aware Minimization (SAM) optimizer in my code, using the already built PyTorch implementation from here. I would also like to use gradient accumulation on top of it, but I have no idea how to make this work properly. As a starting point, I'm using the mixed-precision training loop proposed in one of the closed issues:

import torch.nn.functional as F
from torch.cuda.amp import autocast


def train(
    args, model, device, train_loader, optimizer, first_step_scaler, second_step_scaler, epoch
):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()

        enable_running_stats(model)
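        # (enable_running_stats / disable_running_stats are the BatchNorm helpers from the
        # SAM repo's example code: running stats should only be updated during the first
        # forward pass, not during the second pass on the perturbed weights)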
        # First forward step
        with autocast():
            output = model(data)
            loss = F.nll_loss(output, target)
        first_step_scaler.scale(loss).backward()

        # We unscale manually for two reasons: (1) SAM's first step adds the gradient
        # to the weights directly, so the gradient must be unscaled; (2) unscale_ checks
        # whether any gradient is inf and updates optimizer_state["found_inf_per_device"]
        # accordingly. We use optimizer_state["found_inf_per_device"] to decide whether
        # to apply SAM's first step or not.
        first_step_scaler.unscale_(optimizer)

        optimizer_state = first_step_scaler._per_optimizer_states[id(optimizer)]

        # Check if any gradients are inf/nan
        inf_grad_cnt = sum(v.item() for v in optimizer_state["found_inf_per_device"].values())

        if inf_grad_cnt == 0:
            # If the gradient is valid, apply SAM's first step
            optimizer.first_step(zero_grad=True, mixed_precision=True)
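            # (first_step uses the just-unscaled gradient to move the weights to the
            #  perturbed point w + e(w); the actual parameter update only happens after
            #  the second forward/backward pass)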
            sam_first_step_applied = True
        else:
            # If the gradient is invalid, skip SAM and revert to a single optimization step
            optimizer.zero_grad()
            sam_first_step_applied = False

        # Update the scaler with no impact on the model (weights or gradients). This update
        # resets optimizer_state["found_inf_per_device"], so it is applied after computing
        # inf_grad_cnt. Note that zero_grad() has no impact on the update() operation,
        # because update() leverages optimizer_state["found_inf_per_device"].
        first_step_scaler.update()

        disable_running_stats(model)
        # Second forward step
        with autocast():
            output = model(data)
            loss = F.nll_loss(output, target)
        second_step_scaler.scale(loss).backward()

        if sam_first_step_applied:
            # If sam_first_step was applied, apply the 2nd step
            optimizer.second_step(mixed_precision=True)

        second_step_scaler.step(optimizer)
        second_step_scaler.update()
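
For context, this is roughly how I construct the SAM optimizer and the two scalers and call the loop above (model, train_loader and args are the usual MNIST-example objects; the SGD base optimizer, rho=0.05 and the 2**8 initial scale are just my own choices, not something the repo prescribes):

import torch
from torch.cuda.amp import GradScaler

from sam import SAM  # the (mixed-precision-patched) SAM class from the linked repo

base_optimizer = torch.optim.SGD  # SAM wraps a regular base optimizer
optimizer = SAM(model.parameters(), base_optimizer, rho=0.05, lr=0.1, momentum=0.9)

# One scaler per SAM step, as in the issue; 2**8 is just my initial scale
first_step_scaler = GradScaler(2 ** 8)
second_step_scaler = GradScaler(2 ** 8)

for epoch in range(1, args.epochs + 1):
    train(args, model, device, train_loader, optimizer,
          first_step_scaler, second_step_scaler, epoch)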

To add gradient accumulation, I tried something like this:

def train(
    args, model, device, train_loader, optimizer, first_step_scaler, second_step_scaler, epoch, gradient_acc=2
):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        enable_running_stats(model)
        # First forward step
        with autocast():
            output = model(data)
            loss = F.nll_loss(output, target)
        
        # Scale the loss down so the accumulated gradient averages over gradient_acc micro-batches
        loss = loss / gradient_acc
        first_step_scaler.scale(loss).backward()

        # We unscale manually for two reasons: (1) SAM's first step adds the gradient
        # to the weights directly, so the gradient must be unscaled; (2) unscale_ checks
        # whether any gradient is inf and updates optimizer_state["found_inf_per_device"]
        # accordingly. We use optimizer_state["found_inf_per_device"] to decide whether
        # to apply SAM's first step or not.
        first_step_scaler.unscale_(optimizer)

        optimizer_state = first_step_scaler._per_optimizer_states[id(optimizer)]

        # Check if any gradients are inf/nan
        inf_grad_cnt = sum(v.item() for v in optimizer_state["found_inf_per_device"].values())

        if inf_grad_cnt == 0:
            # If the gradient is valid, apply SAM's first step
            optimizer.first_step(zero_grad=True, mixed_precision=True)
            sam_first_step_applied = True
        else:
            # If the gradient is invalid, skip SAM and revert to a single optimization step
            optimizer.zero_grad()
            sam_first_step_applied = False

        # Update the scaler with no impact on the model (weights or gradients). This update
        # resets optimizer_state["found_inf_per_device"], so it is applied after computing
        # inf_grad_cnt. Note that zero_grad() has no impact on the update() operation,
        # because update() leverages optimizer_state["found_inf_per_device"].
        first_step_scaler.update()

        disable_running_stats(model)
        # Second forward step
        with autocast():
            output = model(data)
            loss = F.nll_loss(output, target)
        loss = loss / gradient_acc
        second_step_scaler.scale(loss).backward()

        if sam_first_step_applied:
            # If sam_first_step was applied, apply the 2nd step
            optimizer.second_step(mixed_precision=True)

        # Only step the base optimizer and update the scaler every gradient_acc batches
        if (batch_idx + 1) % gradient_acc == 0:
            second_step_scaler.step(optimizer)
            second_step_scaler.update()
            optimizer.zero_grad()

But I noticed that this makes my loss increase rather than decrease. Does anyone have an idea how to fix this?
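
One alternative I have been considering (untested, and leaving AMP out for the moment to keep it short) is to buffer the micro-batches and run both SAM passes over the whole group, so that the perturbation and the update both see the full accumulated gradient. This uses the plain first_step / second_step API from the repo rather than the mixed_precision variant above:

    micro_batches = []
    for batch_idx, (data, target) in enumerate(train_loader):
        micro_batches.append((data.to(device), target.to(device)))
        if (batch_idx + 1) % gradient_acc != 0:
            continue  # keep collecting until we have a full effective batch

        # First SAM pass: accumulate gradients at the current weights, then perturb
        enable_running_stats(model)
        for data, target in micro_batches:
            loss = F.nll_loss(model(data), target) / gradient_acc
            loss.backward()
        optimizer.first_step(zero_grad=True)

        # Second SAM pass: accumulate gradients at the perturbed weights, then update
        disable_running_stats(model)
        for data, target in micro_batches:
            loss = F.nll_loss(model(data), target) / gradient_acc
            loss.backward()
        optimizer.second_step(zero_grad=True)

        micro_batches = []

I haven't figured out how to fold the two GradScalers back into this version though, and I'm not sure it is the right approach in the first place.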
