I'm trying to train a variation of a U-Net on a TPU, and three ops appear to be using 24 GB of memory between them (8 GB each). The network is quite large, so I can't tell where in the model these ops are. How do you map the opaque names in this error report back to the actual operations?
RuntimeError: Compilation failed: Compilation failure: Ran out of memory in memory space hbm. Used 27.90G of 16.00G hbm. Exceeded hbm capacity by 11.90G.
Total hbm usage >= 27.90G:
reserved 528.00M
program 27.38G
arguments unknown size
Output size unknown.
Program hbm requirement 27.38G:
reserved 12.0K
scoped 1.0K
HLO temp 27.38G (5.6% utilization, 0.0% fragmentation (1.14M))
Largest program allocations in hbm:
1. Size: 8.00G
Operator: op_type="CrossReplicaSum" op_name="tpu_139655909282424/CrossReplicaSum"
Shape: f32[256,512,128,2]{3,2,1,0}
Unpadded size: 128.00M
Extra memory due to padding: 7.88G (64.0x expansion)
XLA label: %cross-replica-sum = f32[256,512,128,2]{3,2,1,0} cross-replica-sum(f32[256,512,128,2]{3,2,1,0} %bitcast.1), replica_groups={{0,1,2,3,4,5,6,7}}, barrier="custom:0", to_apply=%sum.902, metadata={op_type="CrossReplicaSum" op_name="tpu_139655909282424/CrossRep...
Allocation type: HLO temp
==========================
2. Size: 8.00G
Operator: op_type="Mul" op_name="tpu_139655909282424/mul_1"
Shape: f32[8,32,512,128,2]{4,3,2,1,0}
Unpadded size: 128.00M
Extra memory due to padding: 7.88G (64.0x expansion)
XLA label: %fusion.4 = (f32[8,32,512,128,2]{4,3,2,1,0}, f32[8,32,512,128,2]{4,3,2,1,0}) fusion(f32[8]{0} %fusion.1265, f32[32,512,128,2]{3,2,1,0} %reshape.319, f32[32,512,128,2]{3,2,1,0} %copy.5), kind=kLoop, calls=%fused_computation.4, metadata={op_type="Mul" op_nam...
Allocation type: HLO temp
==========================
3. Size: 8.00G
Operator: op_type="Mul" op_name="tpu_139655909282424/mul_1"
Shape: f32[8,32,512,128,2]{4,3,2,1,0}
Unpadded size: 128.00M
Extra memory due to padding: 7.88G (64.0x expansion)
XLA label: %fusion.4 = (f32[8,32,512,128,2]{4,3,2,1,0}, f32[8,32,512,128,2]{4,3,2,1,0}) fusion(f32[8]{0} %fusion.1265, f32[32,512,128,2]{3,2,1,0} %reshape.319, f32[32,512,128,2]{3,2,1,0} %copy.5), kind=kLoop, calls=%fused_computation.4, metadata={op_type="Mul" op_nam...
Allocation type: HLO temp
==========================
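As a sanity check on the "64.0x expansion" lines in the report, here is a minimal sketch of the arithmetic. It assumes (the log does not state this) that the TPU layout pads the two minor-most dimensions of an f32 array up to multiples of 8 and 128 respectively. The helper names are hypothetical and exist only for illustration.

```python
import math

def unpadded_bytes(shape, dtype_bytes=4):
    """Raw size of the array: product of the dims times bytes per element."""
    return math.prod(shape) * dtype_bytes

def padded_bytes(shape, dtype_bytes=4, tile=(8, 128)):
    """Size after rounding the two minor-most dims up to the ASSUMED tile.

    tile=(8, 128) is an assumption about the TPU memory layout for f32,
    not something taken from the error report.
    """
    dims = list(shape)
    dims[-1] = math.ceil(dims[-1] / tile[1]) * tile[1]  # minor-most dim
    dims[-2] = math.ceil(dims[-2] / tile[0]) * tile[0]  # second minor-most dim
    return math.prod(dims) * dtype_bytes

# Shape reported for the CrossReplicaSum allocation: f32[256,512,128,2]
shape = (256, 512, 128, 2)
raw = unpadded_bytes(shape)
padded = padded_bytes(shape)
expansion = padded / raw

print(f"unpadded: {raw / 2**20:.2f}M")
print(f"padded:   {padded / 2**30:.2f}G")
print(f"expansion: {expansion:.1f}x")
```

Under that assumption, the trailing dimension of size 2 is what gets rounded up to 128, which would account for the whole expansion. The same arithmetic applies to the f32[8,32,512,128,2] shape reported for the two mul_1 allocations.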