I have a program made up of several functions, each with its own JAX calls. Here is the main function:
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=("numberOfVoxels",))
def process_valid_voxels(numberOfVoxels, voxelPositions, voxelLikelihoods, ps, t, M, tmp):
    # Apply process_voxel to every voxel index in one vectorized call.
    func = lambda tmp_val: process_voxel(tmp_val, voxelPositions, voxelLikelihoods, ps, t, M, tmp)
    ys, likelihoods = jax.vmap(func)(jnp.arange(numberOfVoxels))
    return ys, likelihoods
This is the output of ys and likelihoods:
(Pdb) ys
Traced<ShapedArray(int32[3700,3,1])>with<DynamicJaxprTrace(level=3/0)>
(Pdb) likelihoods
Traced<ShapedArray(float32[3700,7,1])>with<DynamicJaxprTrace(level=3/0)>
I want to get the concrete values out of the traced arrays ys and likelihoods so that I can modify them. I have tried using the io_callback function:
def callback1(x):
    return jax.experimental.io_callback(process_voxel, x, x)

a = callback1(jnp.arange(numberOfVoxels))
But the result is still a traced array; only its shape is different:
Traced<ShapedArray(int32[3700])>with<DynamicJaxprTrace(level=3/0)>
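For reference, here is a minimal sketch of the behaviour in question. It rests on some assumptions: `toy_process_voxel`, `toy_process_valid_voxels`, `n`, and `scale` are hypothetical stand-ins, since the real `process_voxel` is not shown. It illustrates that values are tracers only while the jitted function is being traced. `jax.debug.print` can show runtime values from inside, and the arrays returned to non-jitted code are concrete and can be copied or updated.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np


# Hypothetical stand-in for process_voxel; the real function is not shown in the post.
def toy_process_voxel(i, scale):
    y = jnp.stack([i, i + 1, i + 2])[:, None]  # int32, shape (3, 1)
    likelihood = (scale * i.astype(jnp.float32)) * jnp.ones((7, 1), jnp.float32)
    return y, likelihood


@partial(jax.jit, static_argnames=("n",))
def toy_process_valid_voxels(n, scale):
    ys, likelihoods = jax.vmap(lambda i: toy_process_voxel(i, scale))(jnp.arange(n))
    # Here ys and likelihoods are tracers; jax.debug.print shows their runtime values.
    jax.debug.print("first y: {}", ys[0, :, 0])
    return ys, likelihoods


# Called from ordinary (non-jitted) Python, the results are concrete arrays.
ys, likelihoods = toy_process_valid_voxels(4, 0.5)

# Option 1: take a writable NumPy copy and modify it in place.
ys_np = np.array(ys)
ys_np[0, 0, 0] = 99

# Option 2: stay in JAX and build an updated array functionally.
likelihoods_mod = likelihoods.at[0, 0, 0].set(1.0)
```

If the modification has to happen inside the jitted function itself, the functional `.at[...].set(...)` form is the one that works on traced values, because in-place assignment is not supported on JAX arrays.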