TF provides the `TensorArray` to make automatic iteration and stacking efficient in `scan` or `while_loop`.
The naive variants, with gathering and concatenating or with dynamic updates, would be inefficient with backprop, because backprop would keep a copy of the full array in each iteration. E.g. assume you are collecting `ys` of shape [T,B,F], iterating over t in [0,...,T-1]. Now two possible naive variants (a code sketch of both follows the list):
1. You don't know T in advance. You allocate the initial `ys` tensor of shape [0,B,F], and in each iteration you concatenate a new [B,F] vector (expanded to [1,B,F]) to it, so at step t the current `ys` has shape [t,B,F].
2. You know T in advance. You allocate the initial `ys` tensor of shape [T,B,F], and in each iteration you update `ys[t]` (e.g. via `tensor_scatter_nd_update`).
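To make the two variants concrete, here is a minimal JAX sketch of both (the `step` function is made up; variant 1 only works in a plain Python loop, since the shape of `ys` changes every step, while variant 2 also works inside a traced loop):

```python
import jax.numpy as jnp
from jax import lax

T, B, F = 8, 2, 4

def step(t):
    # made-up per-step computation producing a [B, F] output
    return jnp.ones((B, F)) * t

# Variant 1: grow ys by concatenation; the shape changes every step,
# so this cannot run under scan/while_loop.
ys = jnp.zeros((0, B, F))
for t in range(T):
    ys = jnp.concatenate([ys, step(t)[None]], axis=0)  # now [t+1, B, F]

# Variant 2: preallocate [T, B, F] and write index t in each step.
def body(t, ys):
    return lax.dynamic_update_index_in_dim(ys, step(t), t, axis=0)

ys2 = lax.fori_loop(0, T, body, jnp.zeros((T, B, F)))
```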
I have seen the concat variant used in self-attention implementations.
I was checking JAX `while_loop` and `scan`, and it seems JAX does not have a `TensorArray`, but instead uses `dynamic_index_in_dim`/`dynamic_slice` and `dynamic_update_index_in_dim`/`dynamic_update_slice` (similar to `tensor_scatter_nd_update`), i.e. variant 2.
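For completeness, this is how `lax.scan` collects per-step outputs: the body returns the per-step value, and `scan` stacks them for you (minimal example with a trivial made-up step):

```python
import jax.numpy as jnp
from jax import lax

def f(carry, x):
    y = carry + x      # trivial made-up per-step computation
    return y, y        # (new carry, per-step output to be stacked)

xs = jnp.ones((16, 2, 4))                       # [T, B, F] inputs
carry, ys = lax.scan(f, jnp.zeros((2, 4)), xs)  # ys has shape [T, B, F]
```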
Without considering backprop, variant 2 can be efficient if the update happens in place, actually even more so than `TensorArray`. But if it does not update in place for some reason, you get O(T^2) runtime, since each step then copies the full [T,B,F] buffer. Variant 1 would also likely lead to O(T^2) runtime (step t copies the t·B·F elements collected so far, and summing that over t is quadratic in T), unless it is very clever and preallocates a bigger tensor, growing it geometrically like C++ `std::vector`; then it might get away with close to O(T) amortized runtime. But I very much doubt it does that.
When considering backprop, it is much worse, unless some clever optimization happens. In the standard case, backprop would need to store a copy of the full `ys` tensor from every iteration. So you get T copies of a [T,B,F] tensor, i.e. O(T^2) memory consumption.
With the TF `TensorArray`, this is not the case, as each tensor `ys[t]` is treated separately. It is efficient, with only O(T) runtime and memory consumption, even with backprop.
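For reference, a minimal sketch of the `TensorArray` pattern I mean on the TF side (the per-step computation is made up):

```python
import tensorflow as tf

@tf.function
def collect(T, B=2, F=4):
    ta = tf.TensorArray(tf.float32, size=T)
    for t in tf.range(T):
        y = tf.ones((B, F)) * tf.cast(t, tf.float32)  # made-up per-step output [B, F]
        ta = ta.write(t, y)  # write() returns the updated TensorArray
    return ta.stack()        # stacked result of shape [T, B, F]
```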
So, my question is: Is JAX `scan` really as inefficient as I described, especially for the case of backprop? Or if not, why not? How does it avoid the O(T^2) memory consumption in backprop? Is this some automatic optimization? How does that work?
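In case it helps with answering: one way to look at what actually happens is to inspect the lowered and compiled HLO of a small scan-plus-grad function via the `jax.jit` ahead-of-time API (assuming a recent JAX version; the toy `loss` here is made up just to have a `scan` under `grad`):

```python
import jax
import jax.numpy as jnp
from jax import lax

def loss(x0):
    def step(carry, _):
        carry = jnp.tanh(carry)   # made-up per-step computation
        return carry, carry
    _, ys = lax.scan(step, x0, None, length=128)
    return jnp.sum(ys)

x0 = jnp.zeros((8, 32))
# StableHLO as produced by JAX's lowering:
print(jax.jit(jax.grad(loss)).lower(x0).as_text())
# Optimized HLO after XLA compilation, where one can look at what
# became of the dynamic-update-slice ops:
print(jax.jit(jax.grad(loss)).lower(x0).compile().as_text())
```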
I implemented a script to test this, measuring memory and runtime for different sequence lengths n.
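A simplified sketch of the kind of benchmark I mean (timing only; memory can be inspected separately, e.g. with `jax.profiler.save_device_memory_profile`, which I leave out here):

```python
import time
import jax
import jax.numpy as jnp
from jax import lax

def make_loss(T, B=8, F=32):
    def loss(x0):
        def step(carry, _):
            carry = jnp.tanh(carry)   # made-up per-step computation
            return carry, carry
        _, ys = lax.scan(step, x0, None, length=T)  # ys: [T, B, F]
        return jnp.sum(ys)
    return loss

x0 = jnp.zeros((8, 32))
for T in [256, 1024, 4096, 16384]:
    grad_fn = jax.jit(jax.grad(make_loss(T)))
    grad_fn(x0).block_until_ready()   # warm-up / compile
    t0 = time.perf_counter()
    grad_fn(x0).block_until_ready()
    print(T, time.perf_counter() - t0)
```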
So, it seems the memory consumption is actually linear, but the runtime seems quadratic.
A somewhat related question is JAX issue #3106 on a `TensorArray` equivalent, but it doesn't really answer my question here on efficiency.
The same question was also asked in JAX issue #15906, with some answers.