It seems that `torchvision.ops.Conv2dNormActivation` only accepts activation functions defined under `torch.nn`, because the `activation_layer` argument is declared as `Callable[..., torch.nn.Module]` in the source code.
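`Callable[..., torch.nn.Module]` is only a type hint, and Python does not enforce it at runtime. It describes any callable that returns a module when called. A class such as `torch.nn.ReLU` fits because calling the class constructs a module, and a user-defined `nn.Module` subclass fits just as well. The sketch below uses a hypothetical helper, `make_activation`, and a hypothetical toy activation, `Square`. It assumes the block builds its activation roughly as `activation_layer(**params)`, which is a reading of the torchvision source and not a documented guarantee. It also shows what happens when an *instance* is passed instead of a class.

```python
from typing import Callable, Optional

import torch
import torch.nn as nn


class Square(nn.Module):
    """Hypothetical toy activation, used only for illustration."""

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace  # accepted for compatibility, unused

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * x


def make_activation(
    activation_layer: Callable[..., nn.Module], inplace: Optional[bool]
) -> nn.Module:
    # Hypothetical helper: builds the layer by *calling* the factory,
    # passing `inplace` only when it is not None.
    params = {} if inplace is None else {"inplace": inplace}
    return activation_layer(**params)


relu = make_activation(nn.ReLU, inplace=False)  # built-in class: OK
square = make_activation(Square, inplace=None)  # custom class: also OK

x = torch.tensor([-2.0, 3.0])
print(relu(x))    # tensor([0., 3.])
print(square(x))  # tensor([4., 9.])

# Passing an instance means the helper calls Square()() with no input,
# which runs forward() without x and raises TypeError.
instance_failed = False
try:
    make_activation(Square(), inplace=None)
except TypeError as err:
    instance_failed = True
    print("TypeError:", err)
```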
I tried defining a custom activation function, for example:

```python
import torch
import torch.nn as nn


class ExpExp(nn.Module):
    __constants__ = ['inplace']
    inplace: bool

    def __init__(self, inplace: bool = False):
        super(ExpExp, self).__init__()
        self.inplace = inplace  # store the flag so extra_repr() can read it

    def forward(self, x):
        return x * torch.exp(torch.exp(x))

    def extra_repr(self):
        inplace_str = 'inplace=True' if self.inplace else ''
        return inplace_str
```
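An `nn.Module` receives its input when the *instance* is called, not when it is constructed. `ExpExp()` only creates the layer, and `ExpExp()(x)` runs `forward(x)` through `nn.Module.__call__`. Below is a minimal standalone check. The class is repeated (with `super().__init__()` and `self.inplace` set) so the snippet runs on its own, and the input shape is an arbitrary choice.

```python
import torch
import torch.nn as nn


class ExpExp(nn.Module):
    __constants__ = ['inplace']
    inplace: bool

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        return x * torch.exp(torch.exp(x))

    def extra_repr(self):
        return 'inplace=True' if self.inplace else ''


act = ExpExp()          # construction: no input is needed here
x = torch.zeros(2, 3)   # arbitrary example input
y = act(x)              # call: nn.Module.__call__ runs forward(x)
print(act)              # ExpExp()
print(y.shape)          # torch.Size([2, 3])
```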
I then called it like this:

```python
Conv2dNormActivation(
    in_channels,
    out_channels,
    kernel_size=(h, w),
    stride=stride,
    padding=padding,
    norm_layer=norm_layer,
    activation_layer=ExpExp(),
    inplace=None,
)
```
It raised `TypeError: forward() missing 1 required positional argument: 'x'`.
I think I have to pass the input `x` to the `ExpExp()` activation function when using it in `Conv2dNormActivation`, but how do I do that? Is there a way to specify a custom activation function?