TF provides the TensorArray to make collecting and stacking per-step outputs efficient in scan or while_loop.

The naive variants, gathering and concatenating or using dynamic updates, would be inefficient with backprop, because backprop would keep copies of the full array in each iteration. For example, assume you are collecting ys of shape [T,B,F], iterating over t in [0,...,T-1]. There are two possible naive variants (sketched in code right after this list):

  1. You don't know T in advance. You allocate the initial ys tensor with shape [0,B,F], and in each iteration you concatenate a new tensor of shape [B,F], expanded to [1,B,F], onto it, so after step t the current ys has shape [t+1,B,F].

  2. You know T in advance. You can allocate the initial ys tensor with shape [T,B,F]. In each iteration, you update ys[t] (e.g. via tensor_scatter_nd_update).
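
To make this concrete, here are minimal sketches of both variants in JAX-style numpy (plain Python loops just for illustration; step_fn is a placeholder for the per-step computation):

```python
import jax.numpy as jnp

def variant_1_concat(x0, step_fn, T):
    # Variant 1: T unknown; ys grows from [0,B,F] via concat each step.
    B, F = x0.shape
    ys = jnp.zeros((0, B, F), dtype=x0.dtype)
    x = x0
    for t in range(T):  # plain Python loop, only for illustration
        x = step_fn(x)
        ys = jnp.concatenate([ys, x[None]], axis=0)  # copies all of ys
    return ys  # [T, B, F]

def variant_2_update(x0, step_fn, T):
    # Variant 2: T known; preallocate [T,B,F], update ys[t] each step.
    B, F = x0.shape
    ys = jnp.zeros((T, B, F), dtype=x0.dtype)
    x = x0
    for t in range(T):
        x = step_fn(x)
        ys = ys.at[t].set(x)  # lowers to dynamic_update_slice
    return ys  # [T, B, F]
```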

I have seen the concat variant used in self-attention implementations.

I was checking JAX while_loop and scan, and it seems JAX does not have a TensorArray but instead uses dynamic_index_in_dim/dynamic_slice and dynamic_update_index_in_dim/dynamic_update_slice (like tensor_scatter_nd_update), i.e. like variant 2.
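
For reference, the standard lax.scan idiom, which stacks the per-step outputs itself (the per-step computation here is a placeholder):

```python
import jax.numpy as jnp
from jax import lax

def collect_with_scan(x0, T):
    # scan stacks the per-step output y into ys of shape [T,B,F];
    # as far as I can tell, this is lowered to dynamic_update_slice
    # writes into a preallocated buffer, i.e. variant 2.
    def body(x, _):
        y = jnp.tanh(x)  # placeholder per-step computation
        return y, y      # (carry, per-step output to be stacked)

    _, ys = lax.scan(body, x0, xs=None, length=T)
    return ys  # [T, B, F]
```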

Without considering backprop, variant 2 can be efficient if it updates in place, actually more so than TensorArray. But if it does not update in place for some reason, copying the full [T,B,F] buffer every step gives O(T^2) runtime. Variant 1 likewise leads to O(T^2) runtime, since the concat at step t copies O(t*B*F) elements, unless the framework is very clever and preallocates a larger tensor; with geometric growth, as in C++ std::vector, it could get away with amortized O(T) total copying. But I very much doubt that happens.

When considering backprop, it is much worse, unless some clever optimizations happen. In the standard case, backprop needs to store a copy of the full ys tensor from every iteration. So you get T copies of ys, which means O(T^2) memory consumption.
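
To make the O(T^2) concrete: in variant 1, the stored copy of ys at step t has t*B*F elements, so the total kept for backprop is sum_{t=1}^{T} t*B*F = T*(T+1)/2 * B*F; in variant 2, every copy already has T*B*F elements, giving T^2*B*F. Either way, memory grows quadratically in T.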

With TF TensorArray, this is not the case, as each tensor ys[t] is treated separately. It is efficient: O(T) runtime and memory consumption, even with backprop.
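
For comparison, a minimal sketch of that pattern in TF (again with a placeholder per-step computation):

```python
import tensorflow as tf

def collect_with_tensor_array(x0, T):
    # Each ta.write(t, y) stores ys[t] as a separate tensor, so backprop
    # keeps T tensors of shape [B,F] rather than T copies of [T,B,F].
    ta = tf.TensorArray(dtype=x0.dtype, size=T, element_shape=x0.shape)

    def cond(t, x, ta):
        return t < T

    def body(t, x, ta):
        y = tf.tanh(x)  # placeholder per-step computation
        return t + 1, y, ta.write(t, y)

    _, _, ta = tf.while_loop(cond, body, (0, x0, ta))
    return ta.stack()  # [T, B, F]
```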

So, my question is: Is JAX scan really as inefficient as I described, especially with backprop? If not, why not? How does it avoid the O(T^2) memory consumption in backprop? Is this some automatic optimization? How does it work?


I implemented a small script to test this, measuring memory and runtime for different sequence lengths n (a sketch follows below). I get this: (plots: memory consumption, runtime)
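
A minimal sketch of that kind of benchmark (not the exact script; the memory measurement is left to external tooling):

```python
import time
import jax
import jax.numpy as jnp
from jax import lax

def loss(x0, T):
    def body(x, _):
        y = jnp.tanh(x)  # placeholder per-step computation
        return y, y
    _, ys = lax.scan(body, x0, xs=None, length=T)
    return jnp.sum(ys)

grad_fn = jax.jit(jax.grad(loss), static_argnums=1)

for n in [100, 1000, 10000]:  # sequence lengths
    x0 = jnp.ones((8, 16))  # [B, F]
    grad_fn(x0, n).block_until_ready()  # warm-up / compile
    t0 = time.time()
    grad_fn(x0, n).block_until_ready()
    print(n, time.time() - t0)
    # memory: inspect e.g. via jax.profiler or OS-level tools
```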

So, it seems the memory consumption is actually linear, but the runtime seems quadratic.


There is a somewhat related question in JAX issue #3106 about a TensorArray equivalent, but it doesn't really answer my question here on efficiency.

The same question was asked in JAX issue #15906, with some answers.
