
It seems that torchvision.ops.Conv2dNormActivation only accepts activation functions defined under torch.nn, because the activation_layer argument is declared as Callable[..., torch.nn.Module] in the source code.

I tried defining a custom activation function (as an example) like this:

import torch
import torch.nn as nn

class ExpExp(nn.Module):
    __constants__ = ['inplace']
    inplace: bool

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        return x * torch.exp(torch.exp(x))

    def extra_repr(self):
        return 'inplace=True' if self.inplace else ''

And called it as:

Conv2dNormActivation(
    in_channels,
    out_channels,
    kernel_size=(h, w),
    stride=stride,
    padding=padding,
    norm_layer=norm_layer,
    activation_layer=ExpExp(),
    inplace=None,
)

It returned the error: `TypeError: forward() missing 1 required positional argument: 'x'`

I think I have to give the input x to the ExpExp() activation function when it is called inside Conv2dNormActivation. But how do I do that? Is there a way to specify a custom activation function?
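For context, a minimal pure-Python sketch of what seems to be going on (the names `ExpExpStandIn` and `build_activation` are stand-ins I made up, not torchvision code, but the call pattern mirrors how Conv2dNormActivation instantiates its `activation_layer`): the argument is treated as a *factory* that the block calls itself, so passing an already-constructed instance means the module gets called with no input tensor, which triggers exactly this `forward()` error.

```python
class ExpExpStandIn:
    """Stand-in for the custom nn.Module activation (no torch required)."""
    def __init__(self, inplace: bool = False):
        self.inplace = inplace

    def forward(self, x):
        return x  # placeholder for x * exp(exp(x))

    def __call__(self, *args, **kwargs):
        # mimics nn.Module.__call__ dispatching to forward()
        return self.forward(*args, **kwargs)


def build_activation(activation_layer, inplace=None):
    """Mimics how Conv2dNormActivation instantiates activation_layer."""
    params = {} if inplace is None else {"inplace": inplace}
    return activation_layer(**params)  # the factory is CALLED here


# Passing the CLASS works: it is instantiated inside the builder.
act = build_activation(ExpExpStandIn)
assert isinstance(act, ExpExpStandIn)

# Passing an INSTANCE reproduces the question's error: calling the
# instance dispatches to forward() with no input tensor 'x'.
try:
    build_activation(ExpExpStandIn())
except TypeError as e:
    print("TypeError:", e)
```

So, assuming torchvision instantiates `activation_layer` internally in this way, the fix would be to pass the class itself, i.e. `activation_layer=ExpExp` rather than `activation_layer=ExpExp()`.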
