
I have a 2D JAX array containing an image.

For each pixel P[y, x] of the image, I would like to loop over all pixels P[y, x-i] to the left of that pixel and reduce them to a single value. The exact reduction involves finding a particular maximum over a weighted sum involving those pixels' values, as well as i and x. Because the reduction depends on x itself, neither the result nor any intermediate results for P[y, x] can be reused for P[y, x+1]; this is an O(x²y) operation overall.

Can I accomplish this somewhat efficiently in JAX? If so, how?

sk29910

1 Answer


JAX does not provide any native tool for this sort of operation with an arbitrary reduction function. For reductions where each successive value can be computed from the previous one, it can be done with `jax.lax.scan` (or, for simple cases, `jnp.cumsum`), but it sounds like that is not the case here.
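For contrast, here is a minimal sketch of the case that *is* well supported: a reduction whose value at x follows from its value at x-1. A running maximum along each row is used purely as an illustrative stand-in, not the question's reduction. `lax.scan` carries the maximum seen so far, and the result is checked against the built-ins `lax.cummax` and `jnp.cumsum`. The helper name `running_max_scan` is hypothetical, and the `-inf` initial value assumes a floating-point image.

```python
import jax
import jax.numpy as jnp
from jax import lax

def running_max_scan(row):
    # Each output depends only on the previous carry and the current pixel,
    # which is exactly the structure lax.scan requires.
    def step(carry, v):
        new = jnp.maximum(carry, v)
        return new, new  # (next carry, value emitted for this position)

    # Assumes a floating-point row so that -inf is a valid starting value.
    init = jnp.array(-jnp.inf, dtype=row.dtype)
    _, out = lax.scan(step, init, row)
    return out

P = jnp.array([[3., 1., 4., 1., 5.],
               [9., 2., 6., 5., 3.]])

scanned = jax.vmap(running_max_scan)(P)  # scan along each row
builtin_max = lax.cummax(P, axis=1)      # built-in equivalent
builtin_sum = jnp.cumsum(P, axis=1)      # the cumsum case mentioned above
print(scanned)
print(builtin_sum)
```

This only works because the running maximum can be updated from the previous result. The question's reduction depends on x itself, so no such carry exists.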

I believe the best you can do is to combine vmap with Python for-loops. Be aware that JAX unrolls Python for-loops when it traces a function, so the compiled program grows with the number of iterations; if your image is very large, compilation will take a long time. Here's a short example:

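```python
import jax.numpy as jnp
from jax import vmap

def reduction(x):
  # Placeholder 1D reduction; substitute the real computation here.
  assert x.ndim == 1
  return len(x) + jnp.sum(x)

def cumulative_apply(row, reduction=reduction):
  # Apply the reduction to every prefix row[:1], row[:2], ..., row[:len(row)].
  # The Python loop is unrolled at trace time: one slice + reduction per pixel.
  return jnp.stack([reduction(row[:i]) for i in range(1, len(row) + 1)])

P = jnp.arange(20).reshape(4, 5)
# vmap maps cumulative_apply over the rows (axis 0) of the image.
result = vmap(cumulative_apply)(P)
print(result)
# [[ 1  3  6 10 15]
#  [ 6 13 21 30 40]
#  [11 23 36 50 65]
#  [16 33 51 70 90]]
```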
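One way to see the unrolling that drives compile time is to inspect the traced program with `jax.make_jaxpr`. Each iteration of the Python loop becomes its own set of operations, so the program size grows with the row width. This is a minimal sketch that reuses the toy `reduction` from above. The helper `traced_op_count` is hypothetical, and exact operation counts vary across JAX versions.

```python
import jax
import jax.numpy as jnp

def reduction(x):
    # Same toy reduction as in the example above.
    return len(x) + jnp.sum(x)

def cumulative_apply(row):
    # Python loop over prefixes; unrolled when traced.
    return jnp.stack([reduction(row[:i]) for i in range(1, len(row) + 1)])

def traced_op_count(width):
    # Trace (without executing) for a row of the given width and count the
    # top-level operations in the resulting jaxpr.
    closed = jax.make_jaxpr(cumulative_apply)(jnp.zeros(width))
    return len(closed.jaxpr.eqns)

counts = [traced_op_count(w) for w in (5, 50, 200)]
for w, c in zip((5, 50, 200), counts):
    print(f"width={w}: {c} traced ops")
```

The count rises roughly linearly with the width, and XLA must compile every one of those operations.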
jakevdp
Thank you for the example. I had tried that approach already and was hoping for a more efficient option, but it looks like JAX isn't designed for this sort of thing. I'll be using a different approach to solve my problem instead. – sk29910 Apr 22 '22 at 19:38