
Recently I built my first model in Flax. The forward pass worked fine, but I ran into OOM errors during the backward pass.

Originally I had split my model into several small classes, each implemented as its own Flax module inheriting from linen.Module. While debugging, I merged all of these parts into a single Flax module, and this drastically reduced the memory footprint of the backward pass.

Can someone explain whether this is expected behavior and, if so, where the extra memory in the backward pass comes from when the model is built from multiple small modules instead of one large one?

Thanks in advance and best regards.

Simon P.