I'm trying to use the Sharpness-Aware Minimization (SAM) optimizer in my code, using the existing PyTorch implementation from here. I would also like to use gradient accumulation, but I have no idea how to make this work properly. This is the approach proposed in one of the closed issues for mixed precision:
import torch.nn.functional as F
from torch.cuda.amp import autocast
# enable_running_stats / disable_running_stats are the helpers from the SAM repo's example utilities

def train(
    args, model, device, train_loader, optimizer, first_step_scaler, second_step_scaler, epoch
):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()

        enable_running_stats(model)
        # First forward-backward pass
        with autocast():
            output = model(data)
            loss = F.nll_loss(output, target)
        first_step_scaler.scale(loss).backward()

        # We unscale manually for two reasons: (1) SAM's first step adds the gradient
        # to the weights directly, so the gradient must be unscaled; (2) unscale_ checks
        # whether any gradient is inf and updates optimizer_state["found_inf_per_device"]
        # accordingly. We use optimizer_state["found_inf_per_device"] to decide whether
        # to apply SAM's first step or not.
        first_step_scaler.unscale_(optimizer)
        optimizer_state = first_step_scaler._per_optimizer_states[id(optimizer)]

        # Check if any gradients are inf/nan
        inf_grad_cnt = sum(v.item() for v in optimizer_state["found_inf_per_device"].values())

        if inf_grad_cnt == 0:
            # If the gradients are valid, apply SAM's first step
            optimizer.first_step(zero_grad=True, mixed_precision=True)
            sam_first_step_applied = True
        else:
            # If any gradient is invalid, skip SAM and revert to a single optimization step
            optimizer.zero_grad()
            sam_first_step_applied = False

        # Update the scaler with no impact on the model (weights or gradients). This update
        # resets optimizer_state["found_inf_per_device"], so it is applied after computing
        # inf_grad_cnt. Note that zero_grad() has no impact on the update() operation,
        # because update() relies on optimizer_state["found_inf_per_device"].
        first_step_scaler.update()

        disable_running_stats(model)
        # Second forward-backward pass
        with autocast():
            output = model(data)
            loss = F.nll_loss(output, target)
        second_step_scaler.scale(loss).backward()

        if sam_first_step_applied:
            # If SAM's first step was applied, apply the second step
            optimizer.second_step(mixed_precision=True)

        second_step_scaler.step(optimizer)
        second_step_scaler.update()
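For comparison, the plain AMP + gradient accumulation pattern (without SAM) that I want to fold into the loop above looks roughly like this. This is only a minimal sketch; the function and `accum_steps` names are just illustrative:

from torch.cuda.amp import GradScaler, autocast
import torch.nn.functional as F

def train_accum(model, device, train_loader, optimizer, scaler, accum_steps=2):
    model.train()
    optimizer.zero_grad()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        with autocast():
            output = model(data)
            # Divide by accum_steps so the accumulated gradient matches the
            # gradient of the average loss over the accumulated batches
            loss = F.nll_loss(output, target) / accum_steps
        scaler.scale(loss).backward()
        # Only step/update every accum_steps batches; gradients accumulate in between
        if (batch_idx + 1) % accum_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()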
I tried something like this:
def train(
    args, model, device, train_loader, optimizer, first_step_scaler, second_step_scaler, epoch, gradient_acc=2
):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        enable_running_stats(model)
        # First forward-backward pass
        with autocast():
            output = model(data)
            loss = F.nll_loss(output, target)
            loss = loss / gradient_acc
        first_step_scaler.scale(loss).backward()

        # We unscale manually for two reasons: (1) SAM's first step adds the gradient
        # to the weights directly, so the gradient must be unscaled; (2) unscale_ checks
        # whether any gradient is inf and updates optimizer_state["found_inf_per_device"]
        # accordingly. We use optimizer_state["found_inf_per_device"] to decide whether
        # to apply SAM's first step or not.
        first_step_scaler.unscale_(optimizer)
        optimizer_state = first_step_scaler._per_optimizer_states[id(optimizer)]

        # Check if any gradients are inf/nan
        inf_grad_cnt = sum(v.item() for v in optimizer_state["found_inf_per_device"].values())

        if inf_grad_cnt == 0:
            # If the gradients are valid, apply SAM's first step
            optimizer.first_step(zero_grad=True, mixed_precision=True)
            sam_first_step_applied = True
        else:
            # If any gradient is invalid, skip SAM and revert to a single optimization step
            optimizer.zero_grad()
            sam_first_step_applied = False

        # Update the scaler with no impact on the model (weights or gradients). This update
        # resets optimizer_state["found_inf_per_device"], so it is applied after computing
        # inf_grad_cnt. Note that zero_grad() has no impact on the update() operation,
        # because update() relies on optimizer_state["found_inf_per_device"].
        first_step_scaler.update()

        disable_running_stats(model)
        # Second forward-backward pass
        with autocast():
            output = model(data)
            loss = F.nll_loss(output, target)
            loss = loss / gradient_acc
        second_step_scaler.scale(loss).backward()

        if sam_first_step_applied:
            # If SAM's first step was applied, apply the second step
            optimizer.second_step(mixed_precision=True)

        # Only step/update the optimizer every gradient_acc batches
        if (batch_idx + 1) % gradient_acc == 0:
            second_step_scaler.step(optimizer)
            second_step_scaler.update()
            optimizer.zero_grad()
But I noticed that this makes my loss increase rather than decrease. Does anyone have an idea how to improve this?
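For context, this is roughly the plain SAM two-step loop (no AMP, no gradient accumulation) that the code above builds on, as I understand it from the SAM repo's example; treat it as a sketch rather than the exact upstream code:

import torch.nn.functional as F

def train_plain_sam(model, device, train_loader, optimizer):
    model.train()
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)

        # First forward-backward pass: compute gradients at the current weights,
        # then perturb the weights towards the nearby "sharp" point.
        enable_running_stats(model)
        loss = F.nll_loss(model(data), target)
        loss.backward()
        optimizer.first_step(zero_grad=True)

        # Second forward-backward pass: compute gradients at the perturbed weights,
        # then undo the perturbation and take the actual optimizer step.
        disable_running_stats(model)
        F.nll_loss(model(data), target).backward()
        optimizer.second_step(zero_grad=True)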