
I want to know the performance characteristics of xla::Reshape. Specifically, I can imagine it could be implemented by simply remapping XlaOp metadata (e.g. buffer addresses) rather than creating a whole new XlaOp. Alternatively, does XLA's fusion or some other technique essentially make it cheap?

The reason I ask is that I'm working out how to map a function over a tensor, for example a function

f : [p, q] -> [r]

over an [n, m, p, q] tensor to get an [n, m, r] result. One option is to flatten the leading dimensions and require that the function accept a single leading (batch) dimension, e.g.

f' : [n, p, q] -> [n, r]

then reshape the result as required. However, this is only feasible if flattening and expanding are performant.
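For concreteness, here is a minimal JAX sketch of the flatten/apply/reshape pattern described above. The f here is just a placeholder that reduces over q (so r = p); it's not any particular library function:

```python
import jax.numpy as jnp

# Placeholder for f : [p, q] -> [r]; here it sums over q, so r == p.
# Because the reduction is over the last axis, it already tolerates a
# single leading batch axis, i.e. it also serves as f' : [b, p, q] -> [b, r].
def f(x):
    return jnp.sum(x, axis=-1)

def map_over_leading(x):            # x : [n, m, p, q]
    n, m, p, q = x.shape
    flat = x.reshape(n * m, p, q)   # flatten leading dims
    out = f(flat)                   # [n*m, r]
    return out.reshape(n, m, -1)    # expand back to [n, m, r]

x = jnp.ones((2, 3, 4, 5))
print(map_over_leading(x).shape)    # (2, 3, 4)
```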

Tagging JAX because I imagine it's the same story there. Of course JAX has vmap/pmap, which make this unnecessary.
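For comparison, the vmap version of the same mapping, which avoids the manual reshapes entirely (same placeholder f as above):

```python
import jax
import jax.numpy as jnp

def f(x):                           # f : [p, q] -> [r]
    return jnp.sum(x, axis=-1)

# Map over both leading axes: [n, m, p, q] -> [n, m, r].
f_mapped = jax.vmap(jax.vmap(f))
x = jnp.ones((2, 3, 4, 5))
print(f_mapped(x).shape)            # (2, 3, 4)
```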


1 Answer


It depends on whether the reshape changes the physical layout of the tensor.

Usually, XLA's reshape involves a change to the tensor's physical layout, and so costs more than a bitcast operation (which does not change the physical layout and therefore has almost no overhead).

However, if the reshape does not change the physical layout, it is cheap: XLA can lower it to a bitcast.
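You can check this empirically from JAX by inspecting the compiled HLO and looking for whether the reshape was lowered to a bitcast. A sketch, assuming a reasonably recent JAX version (the exact HLO text varies by backend and XLA version):

```python
import jax
import jax.numpy as jnp

def flatten_leading(x):             # [n, m, p, q] -> [n*m, p, q]
    n, m, p, q = x.shape
    return x.reshape(n * m, p, q)

x = jnp.ones((2, 3, 4, 5))
compiled = jax.jit(flatten_leading).lower(x).compile()
# Search the optimized HLO for "bitcast" vs. an explicit copy/transpose.
print(compiled.as_text())
```

Flattening contiguous leading dimensions like this preserves the row-major physical layout, so it is exactly the kind of reshape that tends to be layout-preserving and cheap.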

See: low cost reshape