I want to know the performance characteristics of `xla::Reshape`. Specifically, I can imagine it could be implemented by simply remapping `XlaOp` metadata (e.g. buffer addresses) rather than allocating and copying a whole new buffer. Alternatively, does XLA's fusion or some other technique essentially make it cheap?
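One way I thought to probe this (a sketch in JAX, since that's what I have to hand; I'm not sure it's the right way to check) is to dump the optimised HLO and see whether the reshapes survive as real ops or get simplified away/turned into bitcasts:

```python
import jax
import jax.numpy as jnp

def flatten_op_restore(x):
    flat = x.reshape((-1,) + x.shape[2:])  # [n, m, p, q] -> [n*m, p, q]
    flat = flat * 2.0                      # stand-in for the mapped function
    return flat.reshape(x.shape[:2] + flat.shape[1:])  # restore leading dims

x = jnp.ones((2, 3, 4, 5))
compiled = jax.jit(flatten_op_restore).lower(x).compile()
print(compiled.as_text())  # look for bitcast vs. explicit copy/transpose ops
```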
The reason I ask is that I'm working out how to map a function over a tensor: for example, mapping a function

`f : [p, q] -> [r]`

over an `[n, m, p, q]` tensor to get an `[n, m, r]` result. One option I have is to flatten the leading dimensions to `[n*m, p, q]`, require that the function supports a single leading dimension, e.g.

`f' : [n, p, q] -> [n, r]`

and then reshape the result back to `[n, m, r]`. However, this is only feasible if flattening and expanding are performant.
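For concreteness, the flattening option would look something like this (the helper name `map_over_leading` and the `event_ndims` parameter are mine, just for illustration):

```python
import jax.numpy as jnp

def map_over_leading(f, x, event_ndims=2):
    """Sketch: flatten all leading dims, apply f, then restore them.

    f is assumed to accept a single leading (batch) dimension,
    i.e. map [batch, p, q] -> [batch, r].
    """
    batch_shape = x.shape[:-event_ndims]               # (n, m)
    flat = x.reshape((-1,) + x.shape[-event_ndims:])   # [n*m, p, q]
    out = f(flat)                                      # [n*m, r]
    return out.reshape(batch_shape + out.shape[1:])    # [n, m, r]

# Example usage with a toy f' that flattens each [p, q] slice:
x = jnp.ones((2, 3, 4, 5))
y = map_over_leading(lambda a: a.reshape(a.shape[0], -1), x)
print(y.shape)  # (2, 3, 20)
```

Whether this is a good idea hinges entirely on the cost of the two reshapes, hence the question.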
Tagging JAX because I imagine it's the same story there. Of course, JAX has `vmap`/`pmap`, which make this approach unnecessary.
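For reference, the `vmap` version I'd use in JAX (with a toy `f : [p, q] -> [r]` that just flattens) avoids the explicit reshapes entirely by nesting one `vmap` per leading dimension:

```python
import jax
import jax.numpy as jnp

f = lambda a: a.reshape(-1)     # toy f : [p, q] -> [p*q]
x = jnp.ones((2, 3, 4, 5))      # [n, m, p, q]
y = jax.vmap(jax.vmap(f))(x)    # vmap over n and m -> [n, m, p*q]
print(y.shape)                  # (2, 3, 20)
```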