I have a pytree containing arrays of different shapes. For example, it contains:

  • observations of shape (5, 3, 250, 23)
  • dones of shape (5, 3, 250)

I want to reshape my pytree so that the first two dimensions are merged, which would give something like (15, 250, ...) for every object in my pytree.

I usually use tree_map to work on my pytrees, but this time I am struggling to make it work. I tried:

jax.tree_map(lambda x: jnp.reshape(x, newshape=(15, -1)), my_pytree)

This works for dones, but for observations it also merges the trailing dimensions, giving an array of shape (15, 5750), whereas I want (15, 250, 23).
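
A minimal reproduction of the behaviour described above, as a sketch: the dictionary layout and the zero-filled arrays are hypothetical stand-ins, since only the shapes matter here. It uses `jax.tree_util.tree_map` and passes the shape positionally, which should behave the same as the call in the question.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the real pytree; only the shapes are taken from the question.
my_pytree = {
    "observations": jnp.zeros((5, 3, 250, 23)),
    "dones": jnp.zeros((5, 3, 250)),
}

# -1 absorbs every remaining axis, so all trailing dimensions collapse into one.
flat = jax.tree_util.tree_map(lambda x: jnp.reshape(x, (15, -1)), my_pytree)

flat_shapes = {name: leaf.shape for name, leaf in flat.items()}
print(flat_shapes)  # dones: (15, 250), observations: (15, 5750)
```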

Note: I cannot modify the definition of my pytree, I have to work with this structure.

Valentin Macé

1 Answer

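Sorry for the post, it turned out to be fairly trivial. I'm posting the answer in case it helps someone: build the new shape from the merged leading dimension followed by each array's own trailing dimensions, x.shape[2:].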

jax.tree_map(lambda x: jnp.reshape(x, newshape=(15, *x.shape[2:])), my_pytree)
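
The same idea without hard-coding 15, as a sketch under a few assumptions: the helper name `merge_leading_axes` and the example pytree are hypothetical, and every leaf is assumed to have at least two axes. It uses `jax.tree_util.tree_map` and a positional shape argument because, as far as I know, newer JAX releases deprecate the top-level `jax.tree_map` alias and the `newshape` keyword.

```python
import jax
import jax.numpy as jnp


def merge_leading_axes(x):
    # Fold the first two axes into one and keep every trailing axis unchanged.
    return jnp.reshape(x, (x.shape[0] * x.shape[1], *x.shape[2:]))


# Hypothetical stand-in for the real pytree; only the shapes are taken from the question.
my_pytree = {
    "observations": jnp.zeros((5, 3, 250, 23)),
    "dones": jnp.zeros((5, 3, 250)),
}

merged = jax.tree_util.tree_map(merge_leading_axes, my_pytree)

merged_shapes = {name: leaf.shape for name, leaf in merged.items()}
print(merged_shapes)  # dones: (15, 250), observations: (15, 250, 23)
```

Computing the leading size from `x.shape` means the same function still applies if the first two dimensions change, as long as they match across leaves.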
Valentin Macé