I have a pytree containing arrays of different shapes. For example, it contains observations of shape (5, 3, 250, 23) and dones of shape (5, 3, 250).
I want to reshape every array in my pytree so that the first two dimensions are merged, giving a shape like (15, 250, ...) for each leaf.
I usually use tree_map to work on my pytrees, but this time I can't get it to work. I tried:
jax.tree_util.tree_map(lambda x: jnp.reshape(x, (15, -1)), my_pytree)
It works well for dones, but for observations it also merges the trailing dimensions, producing an array of shape (15, 5750) instead of the (15, 250, 23) I want.
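A minimal reproduction of the behaviour described above, assuming for illustration that the pytree is a plain dict with keys "observations" and "dones" (the real structure is not shown). Because the target shape (15, -1) is fixed, every leaf becomes 2-D, and the -1 absorbs all remaining elements: 250 * 23 = 5750.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the real pytree; the dict layout is an assumption.
my_pytree = {
    "observations": jnp.zeros((5, 3, 250, 23)),
    "dones": jnp.zeros((5, 3, 250)),
}

# A fixed 2-D target forces every leaf to 2-D: the -1 swallows all axes after
# the first two, so (5, 3, 250, 23) -> (15, 250 * 23) = (15, 5750).
flat = jax.tree_util.tree_map(lambda x: jnp.reshape(x, (15, -1)), my_pytree)

print(flat["observations"].shape)  # (15, 5750)
print(flat["dones"].shape)         # (15, 250)
```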
Note: I cannot modify the definition of my pytree, I have to work with this structure.
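One way to keep each leaf's trailing dimensions is to build the target shape per leaf from its own x.shape instead of hard-coding it. The sketch below is an illustration under two assumptions: every leaf has at least two dimensions, and the pytree is again a hypothetical dict. Since tree_map returns a pytree with the same structure as its input, the pytree definition itself does not need to change; this holds for any registered pytree node, not just dicts.

```python
import jax
import jax.numpy as jnp


def merge_leading_dims(tree):
    """Merge axes 0 and 1 of every leaf, keeping the remaining axes as-is.

    Assumes every leaf has ndim >= 2 (hypothetical helper, not a JAX API).
    """
    def merge(x):
        # New leading axis = product of the first two; the leaf's own
        # trailing shape (x.shape[2:]) is appended unchanged.
        return jnp.reshape(x, (x.shape[0] * x.shape[1],) + x.shape[2:])

    return jax.tree_util.tree_map(merge, tree)


# Hypothetical pytree with the shapes from the question.
my_pytree = {
    "observations": jnp.zeros((5, 3, 250, 23)),
    "dones": jnp.zeros((5, 3, 250)),
}

merged = merge_leading_dims(my_pytree)
print(merged["observations"].shape)  # (15, 250, 23)
print(merged["dones"].shape)         # (15, 250)
```

Computing the leading size as x.shape[0] * x.shape[1] makes the intent explicit; writing (-1,) + x.shape[2:] would give the same result. jax.lax.collapse(x, 0, 2), which collapses the half-open axis range [0, 2), is another way to express the same merge inside the lambda.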