I have a 2D JAX array containing an image.
For each pixel P[y, x] of the image, I would like to loop over all pixels P[y, x-i] to the left of that pixel and reduce them to a single value. The exact reduction involves finding a particular maximum over a weighted sum that depends on those pixels' values as well as on i and x. Because the weighting depends on x, the result (or any intermediate result) for P[y, x] can't be reused for P[y, x+1] either, so this is an O(x²y) operation overall: quadratic in the image width and linear in its height.
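For concreteness, a minimal brute-force sketch of that computation in JAX might look like the following. The real weighted expression isn't given, so score(v, i, x) is a hypothetical placeholder to be swapped for it; the sketch also assumes "to the left" means offsets i = 1..x and that the reduction is a max over those offsets. Invalid offsets are masked with -inf, so column 0 (no left neighbours) comes out as -inf.

```python
import jax
import jax.numpy as jnp


def score(v, i, x):
    # HYPOTHETICAL placeholder for the real weighted term in the pixel
    # value v, the offset i and the target column x.
    return v - 0.5 * i + 0.1 * x


@jax.jit
def left_max_broadcast(P):
    """out[y, x] = max over i in 1..x of score(P[y, x - i], i, x)."""
    H, W = P.shape
    x = jnp.arange(W)[:, None]   # target column, shape (W, 1)
    xp = jnp.arange(W)[None, :]  # source column x - i, shape (1, W)
    i = x - xp                   # offset for every (x, xp) pair, shape (W, W)
    valid = i >= 1               # keep only pixels strictly to the left
    # P[:, None, :] has shape (H, 1, W); broadcasting gives (H, W, W)
    # with vals[y, x, xp] = score(P[y, xp], x - xp, x).
    vals = score(P[:, None, :], i[None], x[None])
    vals = jnp.where(valid[None], vals, -jnp.inf)
    return vals.max(axis=-1)     # reduce over source columns -> (H, W)
```

This materialises an (H, W, W) intermediate, so memory grows as O(x²y) along with the work, which may be the limiting factor for large images.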
Can I accomplish this somewhat efficiently in JAX? If so, how?
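One possible approach, sketched here under stated assumptions rather than as a known-best answer, is to keep the O(x²y) work but bound memory by looping over the target column x with jax.lax.map and vectorizing over all rows and source columns inside each step. As before, score(v, i, x) is a hypothetical stand-in for the real weighted expression, offsets are assumed to run over i = 1..x, the reduction is assumed to be a max, and empty ranges (column 0) yield -inf.

```python
import jax
import jax.numpy as jnp


def score(v, i, x):
    # HYPOTHETICAL placeholder for the real weighted term in the pixel
    # value v, the offset i and the target column x.
    return v - 0.5 * i + 0.1 * x


@jax.jit
def left_max_map(P):
    """out[y, x] = max over i in 1..x of score(P[y, x - i], i, x)."""
    H, W = P.shape
    xp = jnp.arange(W)  # source columns

    def one_column(x):
        # Compute output column x for every row y at once.
        i = x - xp                                # offsets, shape (W,)
        vals = score(P, i[None, :], x)            # shape (H, W)
        vals = jnp.where((i >= 1)[None, :], vals, -jnp.inf)
        return vals.max(axis=1)                   # shape (H,)

    # lax.map runs one_column sequentially over x, so only one (H, W)
    # slab is live per step instead of an (H, W, W) array.
    cols = jax.lax.map(one_column, jnp.arange(W))  # shape (W, H)
    return cols.T                                  # shape (H, W)
```

The trade-off is parallelism: the x steps run one after another, whereas the fully broadcast form does everything in one fused operation. Processing a few columns per step (for example, with jax.vmap over a chunk of x values inside the mapped function) sits between the two.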