
I am trying to use WandB gradient visualization to debug gradient flow in my neural net on Google Colab. Without WandB logging, training runs without error, using about 11 GB of the 16 GB on the P100 GPU. However, adding the line wandb.watch(model, log='all', log_freq=3) causes a CUDA out-of-memory error.

How does WandB logging create extra GPU memory overhead?

Is there some way to reduce the overhead?
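
For context, this is the call that triggers the OOM, together with lighter-weight variants I have considered (my assumption being that hooking fewer tensors and logging them less often should shrink the temporary allocations made during logging):

import wandb

# Current call: hooks histograms of both parameters and gradients every 3 batches.
# (model is defined in the training-loop code below.)
wandb.watch(model, log='all', log_freq=3)

# Lighter variants I have considered:
# wandb.watch(model, log='gradients', log_freq=100)   # gradients only, logged far less often
# wandb.watch(model, log='parameters', log_freq=100)  # parameters only
# wandb.watch(model, log=None)                        # track the model without histograms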

--adding training loop code--

import torch
import wandb

learning_rate = 0.001
num_epochs = 50

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = MyModel()

model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
wandb.watch(model, log='all', log_freq=3)

for epoch in range(num_epochs):
    train_running_loss = 0.0
    train_accuracy = 0.0

    model.train()

    ## training step
    for i, (name, output_array, input) in enumerate(trainloader):
        
        output_array = output_array.to(device)
        input = input.to(device)
        comb = torch.zeros(1,1,100,1632).to(device)

        ## forward + backprop + loss
        output = model(input, comb)

        loss = my_loss(output, output_array)

        optimizer.zero_grad()

        loss.backward()

        ## update model params
        optimizer.step()

        train_running_loss += loss.detach().item()

        temp = get_accuracy(output, output_array)

        print('check 13')
        !nvidia-smi | grep MiB | awk '{print $9 $10 $11}'

        train_accuracy += temp     
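
A side note on the memory check in the loop above: a Python-side alternative to the nvidia-smi call that I have considered (PyTorch's own allocator counters, so no shell round-trip) is a helper like this:

def print_gpu_memory(tag=''):
    # memory_allocated: bytes currently held by live tensors;
    # memory_reserved: bytes held by PyTorch's caching allocator,
    # which is roughly what nvidia-smi attributes to the process.
    allocated = torch.cuda.memory_allocated() / 1024**2
    reserved = torch.cuda.memory_reserved() / 1024**2
    print(f'{tag} allocated: {allocated:.0f} MiB, reserved: {reserved:.0f} MiB')

# e.g. print_gpu_memory('check 13') in place of the shell command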

-----edit-----

I think WandB is creating an extra copy of the gradient during logging preprocessing. Here is the traceback:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-11-13de83557b55> in <module>()
     60         get_ipython().system("nvidia-smi | grep MiB | awk '{print $9 $10 $11}'")
     61 
---> 62         loss.backward()
     63 
     64         print('check 10')

4 frames
/usr/local/lib/python3.7/dist-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
    253                 create_graph=create_graph,
    254                 inputs=inputs)
--> 255         torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
    256 
    257     def register_hook(self, hook):

/usr/local/lib/python3.7/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    147     Variable._execution_engine.run_backward(
    148         tensors, grad_tensors_, retain_graph, create_graph, inputs,
--> 149         allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
    150 
    151 

/usr/local/lib/python3.7/dist-packages/wandb/wandb_torch.py in <lambda>(grad)
    283             self.log_tensor_stats(grad.data, name)
    284 
--> 285         handle = var.register_hook(lambda grad: _callback(grad, log_track))
    286         self._hook_handles[name] = handle
    287         return handle

/usr/local/lib/python3.7/dist-packages/wandb/wandb_torch.py in _callback(grad, log_track)
    281             if not log_track_update(log_track):
    282                 return
--> 283             self.log_tensor_stats(grad.data, name)
    284 
    285         handle = var.register_hook(lambda grad: _callback(grad, log_track))

/usr/local/lib/python3.7/dist-packages/wandb/wandb_torch.py in log_tensor_stats(self, tensor, name)
    219         # Remove nans from tensor. There's no good way to represent that in histograms.
    220         flat = flat[~torch.isnan(flat)]
--> 221         flat = flat[~torch.isinf(flat)]
    222         if flat.shape == torch.Size([0]):
    223             # Often the whole tensor is nan or inf. Just don't log it in that case.

RuntimeError: CUDA out of memory. Tried to allocate 4.65 GiB (GPU 0; 15.90 GiB total capacity; 10.10 GiB already allocated; 717.75 MiB free; 14.27 GiB reserved in total by PyTorch)
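
That would be consistent with how boolean-mask indexing works in PyTorch: flat[~torch.isinf(flat)] returns a new tensor on the same device, not a view, so each masking pass needs roughly another gradient-sized allocation on top of the gradients PyTorch is already holding, which would explain the 4.65 GiB allocation the traceback shows failing at that line. A minimal sketch of the effect (my own illustration, not wandb's code):

import torch

grad = torch.randn(64, 1024, 1024, device='cuda')   # ~256 MiB stand-in for a large gradient
before = torch.cuda.memory_allocated()

mask = ~torch.isinf(grad)     # bool mask, ~64 MiB (1 byte per element)
filtered = grad[mask]         # new float tensor, up to another ~256 MiB

after = torch.cuda.memory_allocated()
print(f'extra GPU memory during filtering: {(after - before) / 1024**2:.0f} MiB')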

---update----

Indeed, commenting out the offending line flat = flat[~torch.isinf(flat)] lets the logging step just barely fit into GPU memory.
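
Rather than just deleting the inf filter, a local patch I am considering for that log_tensor_stats code (only a sketch of the idea, not anything wandb ships) is to collapse the two masking passes into a single torch.isfinite pass, or to move the filtering onto a CPU copy so the temporary lands in host RAM instead of GPU memory:

import torch

def filter_nonfinite_gpu(flat):
    # One isfinite pass drops both NaNs and infs, so only one extra
    # gradient-sized temporary is created instead of two in sequence.
    return flat[torch.isfinite(flat)]

def filter_nonfinite_cpu(flat):
    # Copy to host first: slower, but the extra allocation no longer
    # competes with the model and activations for GPU memory.
    flat = flat.cpu()
    return flat[torch.isfinite(flat)]

Either way the histogram input should be unchanged, since isfinite removes exactly the values the two original masks removed.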
