I am trying to run a JAX program. It works with a single GPU, but when I increase the allocation to 4 GPUs (with 32 CPUs), I get this error:
INTERNAL: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc:70: NCCL operation ncclAllReduce(send_buffer, recv_buffer, element_count, dtype, reduce_op, comm, gpu_stream) failed: unhandled cuda error
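For reference, here is a minimal sketch of the kind of call that reaches this all-reduce path. It is not the original program: it assumes the multi-GPU code uses `jax.pmap` with `jax.lax.psum`, and the function name `all_reduce_sum` and axis name `"devices"` are made up for illustration. It also sets the NCCL environment variable `NCCL_DEBUG=INFO` before JAX starts, which should make NCCL print more detail about the underlying CUDA failure.

```python
import os

# NCCL_DEBUG is a standard NCCL environment variable. Setting it before the
# GPU backend initializes asks NCCL to log details about failing collectives.
os.environ.setdefault("NCCL_DEBUG", "INFO")

import jax
import jax.numpy as jnp


def all_reduce_sum(x):
    # Hypothetical stand-in for the real computation: sum x across all devices.
    # On multiple GPUs this is the operation that lowers to an NCCL all-reduce.
    return jax.lax.psum(x, axis_name="devices")


n = jax.local_device_count()          # 4 on the failing setup, 1 on the working one
x = jnp.arange(n, dtype=jnp.float32)  # one scalar per device

out = jax.pmap(all_reduce_sum, axis_name="devices")(x)

print(jax.devices())  # confirms which devices JAX actually sees
print(out)            # every entry should equal 0 + 1 + ... + (n - 1)
```

If this small script fails with the same message on 4 GPUs, the problem is likely in the NCCL/CUDA setup of the machine rather than in the program itself.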
Is there any configuration I need to set for multi-GPU runs?
Many thanks