
The context of the problem is that I have a ResNet model in JAX (which is essentially NumPy), and I take the gradient of the class prediction with respect to the input image. This gives me a gradient vector, g, whose components are so small that g[i]**2 underflows to 0. As a result, np.linalg.norm(g) evaluates to 0, and dividing by it gives me nans.

What I've done so far is to check whether the norm is near 0 and, if so, multiply by a constant factor before normalizing, as in g = np.where(np.linalg.norm(g) < 1e-20, g * 1e20, g).

I was thinking maybe I should instead divide by the smallest nonzero element and then normalize. Does anyone have ideas on how to properly normalize this vector?
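One standard way to sidestep the underflow, in the spirit of the rescale-then-normalize idea above, is to divide by the largest absolute component first: the scaled components lie in [-1, 1], so their squares no longer underflow, and the scale factor cancels in the final unit vector. A minimal sketch in plain NumPy (`safe_normalize` is a hypothetical helper name, not a NumPy/JAX function):

```python
import numpy as np

def safe_normalize(g):
    """Return g / ||g|| without underflow in the norm computation.

    Rescaling by the max absolute component first keeps the squared
    values representable; the rescale factor cancels out, so the
    result is the same unit vector mathematically.
    """
    m = np.max(np.abs(g))
    if m == 0.0:
        # Truly zero vector: there is no direction to normalize.
        return g
    scaled = g / m                     # components now in [-1, 1]
    return scaled / np.linalg.norm(scaled)

# Components so small their float32 squares underflow to 0:
g = np.full(4, 1e-30, dtype=np.float32)
print(np.linalg.norm(g))    # 0.0 -> naive g / norm would give nans
print(safe_normalize(g))    # a proper unit vector, [0.5 0.5 0.5 0.5]
```

The same code works with `jax.numpy` in place of `numpy`, although in JAX the `if m == 0.0` branch would need to become a `jnp.where` to stay traceable under `jit`/`grad`.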

Tonechas
  • Can you share the implementation you tried? – Asnim P Ansari May 17 '21 at 16:31
  • I'm also trying something like g = np.where(np.linalg.norm(g) == 0, g * np.finfo(np.float32).max, g), where np.finfo(np.float32).max gives the maximum value of a float32 in NumPy. Note that if the norm of g is 0, we won't have to worry about overflow, since the individual components are << 1. – Parker Jou May 18 '21 at 17:55

0 Answers