
I'm trying to train a model in mixed precision. However, I want a few of the layers to be in full precision for stability reasons. How do I force an individual layer to be float32 when using torch.autocast? In particular, I'd like this to be ONNX-compilable.

Is it something like:

with torch.autocast(device_type='cuda', enabled=False, dtype=torch.float16):
    out = my_unstable_layer(inputs.float())

Edit:

Looks like this is indeed the official method. See the torch docs.
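Roughly, the docs' nested-autocast pattern looks like this (a minimal sketch; the layers and sizes below are made-up placeholders, not from my actual model):

import torch
import torch.nn as nn

# Hypothetical layers standing in for the real model.
flaky = nn.Linear(16, 16).cuda()   # the layer that should stay in float32
other = nn.Linear(16, 16).cuda()   # fine in reduced precision

x = torch.randn(8, 16, device='cuda')

with torch.autocast(device_type='cuda', dtype=torch.float16):
    h = other(x)                       # autocast runs this in float16
    with torch.autocast(device_type='cuda', enabled=False):
        # autocast is disabled here, so nothing is downcast; explicitly cast
        # the incoming float16 activations back to float32.
        h32 = flaky(h.float())         # executes in float32
    out = other(h32)                   # back in the enabled region -> float16

print(h.dtype, h32.dtype, out.dtype)   # torch.float16 torch.float32 torch.float16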

Luke

1 Answer


I think the motivation of torch.autocast is to automate the reduction of precision (not the increase).

If you have functions that need a particular dtype, you should consider using custom_fwd:

import torch

@torch.cuda.amp.custom_fwd(cast_inputs=torch.complex128)
def get_custom(x):
    print('  Decorated function received', x.dtype)

def regular_func(x):
    print('  Regular function received', x.dtype)
    get_custom(x)

x = torch.tensor(0.0, dtype=torch.half, device='cuda')

with torch.cuda.amp.autocast(False):
    print('autocast disabled')
    regular_func(x)

with torch.cuda.amp.autocast(True):
    print('autocast enabled')
    regular_func(x)
This prints:

autocast disabled
  Regular function received torch.float16
  Decorated function received torch.float16
autocast enabled
  Regular function received torch.float16
  Decorated function received torch.complex128
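For training, the place the docs show custom_fwd is on the static forward of a torch.autograd.Function, paired with custom_bwd on backward so the backward pass runs with the same autocast state. A rough sketch along those lines (the class name Float32MM is made up here):

import torch

class Float32MM(torch.autograd.Function):
    """Matrix multiply forced to run in float32, even under autocast."""

    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return a.mm(b)

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, grad):
        a, b = ctx.saved_tensors
        return grad.mm(b.t()), a.t().mm(grad)

a = torch.randn(4, 4, device='cuda', requires_grad=True)
b = torch.randn(4, 4, device='cuda', requires_grad=True)

with torch.cuda.amp.autocast():
    out = Float32MM.apply(a, b)

print(out.dtype)  # torch.float32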

Edit: Using TorchScript

I am not sure how much you can rely on this, due to a comment in the documentation; however, that comment is apparently outdated.

Here is an example where I trace the model with autocast enabled, freeze it, and then use it. The value is indeed cast to the specified type:

class Cast(torch.nn.Module):
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float64)
    def forward(self, x):
        return x

x = torch.tensor(0.0, dtype=torch.half, device='cuda')

with torch.cuda.amp.autocast(True):
    model = torch.jit.trace(Cast().eval(), x)
model = torch.jit.freeze(model)

print(model(x).dtype)

This prints:

torch.float64

But I suggest you validate this approach before using it in a serious application.

Bob
  • Do you know if this will work with torchscript? – Luke Aug 29 '22 at 16:42
  • Not sure, maybe [this](https://pytorch.org/docs/stable/amp.html#:~:text=For%20now%2C%20we%20suggest%20to%20disable%20the%20Jit%20Autocast%20Pass) means that it doesn't. – Bob Aug 29 '22 at 18:16
  • When I use the approach I listed in my question above, it does appear to work in torch. It's just in TorchScript that it fails, so I don't think the decorator is needed. – Luke Aug 29 '22 at 18:26
  • 1
    Check the example I appended to the answer. Does it help? – Bob Aug 29 '22 at 18:38