
I would like to implement the following activation function in PyTorch:

x = T if abs(x) > T else x

I could get close with torch.clamp(x, min=-T, max=T), but that's not exactly the behavior I want: it agrees with the above for x > -T, but returns -T rather than T for x < -T. Is there any torch function that could help me achieve this?
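For concreteness, a small sketch of the mismatch (T = 1.0 and the sample values are arbitrary, for illustration only):

```python
import torch

T = 1.0  # illustrative threshold
x = torch.tensor([-2.0, 0.3, 2.0])

# clamp maps the left tail (x < -T) to -T, which is not what is wanted here:
clamped = torch.clamp(x, min=-T, max=T)
print(clamped)  # tensor([-1.0000,  0.3000,  1.0000])
# desired: both tails map to +T, i.e. tensor([1.0000, 0.3000, 1.0000])
```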

Noé Achache
  • There's a discontinuity at x = -T, which seems like an issue for optimization reasons (for one thing, the function has no derivative or subderivative at x = -T). – jodag Jun 10 '20 at 20:49
  • I know, but this was suggested by a paper: we don't actually care about the truncated values, so whether they are -T or T does not change much. Setting them both to the same value reduces the variance of the feature map (cf. https://arxiv.org/pdf/1912.06540.pdf). – Noé Achache Jun 10 '20 at 21:41

1 Answer


torch.where does exactly that:


x = torch.where(torch.abs(x) > T, torch.full_like(x, T), x)
Michael Jungo
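A minimal, self-contained sketch of the answer in action (T = 1.0 and the input values are arbitrary, chosen only to exercise both tails):

```python
import torch

T = 1.0
x = torch.tensor([-2.0, -0.5, 0.0, 0.5, 2.0])

# Replace every entry whose magnitude exceeds T with +T; keep the rest unchanged.
y = torch.where(torch.abs(x) > T, torch.full_like(x, T), x)
print(y)  # tensor([ 1.0000, -0.5000,  0.0000,  0.5000,  1.0000])
```

Note that both tails (x = -2.0 and x = 2.0) map to +T, unlike torch.clamp, which would send -2.0 to -T.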