When using the DDP backend, a separate process runs for every GPU. The processes don't have access to each other's data, but there are a few special collective operations (reduce, all_reduce, gather, all_gather) that make them synchronize. When you call such an operation on a tensor, the processes wait for each other to reach the same point and then combine their values in some way, for example by taking the sum across all processes.
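As a minimal sketch of what such a collective does, here is torch.distributed.all_reduce called directly. This assumes the default process group has already been initialized (under DDP, Lightning does that for you, so you rarely call this yourself); the tensor contents are made up for illustration:

```python
import torch
import torch.distributed as dist

# Each DDP process (rank) holds its own local value, e.g. the number of
# correct predictions it has seen so far.
local_correct = torch.tensor([42.0])

# all_reduce makes every process wait until all ranks reach this line,
# then replaces the tensor on every rank with the element-wise sum
# over all ranks.
dist.all_reduce(local_correct, op=dist.ReduceOp.SUM)

# local_correct now holds the total across all GPUs, on every process.
```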
In theory it's possible to gather all data from all processes and then calculate the metric in one process, but this is slow and error-prone, so you want to minimize the amount of data you transfer. The easiest approach is to calculate the metric separately in each process and then, for example, take the average. self.log() calls will do this automatically when you use sync_dist=True.
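For instance, logging a per-batch accuracy from a LightningModule could look like this (a minimal sketch; the module, layer sizes, and metric name are made up for illustration):

```python
import torch
import torch.nn as nn
import pytorch_lightning as pl

class LitClassifier(pl.LightningModule):
    def __init__(self, num_features=32, num_classes=10):
        super().__init__()
        self.model = nn.Linear(num_features, num_classes)

    def forward(self, x):
        return self.model(x)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        acc = (logits.argmax(dim=-1) == y).float().mean()
        # With sync_dist=True, Lightning reduces the value across all DDP
        # processes (by default taking the mean) before logging it, instead
        # of logging each process's local value on its own.
        self.log("val_acc", acc, sync_dist=True)
```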
If you don't want to take the average over the GPU processes, it's also possible to update some state variables at each step and, after the epoch, synchronize the state variables and calculate your metric from the synchronized values. The recommended way is to write a metric class using the Metrics API, which recently moved from PyTorch Lightning to the TorchMetrics project.
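A minimal sketch of such a metric, assuming TorchMetrics is installed (the class name and states are only illustrative):

```python
import torch
from torchmetrics import Metric

class MyAccuracy(Metric):
    def __init__(self):
        super().__init__()
        # Scalar states: each process accumulates its own counts, and when
        # the metric is synchronized the counts are summed across processes.
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        self.correct += (preds == target).sum()
        self.total += target.numel()

    def compute(self):
        # Called after synchronization, so this is the accuracy over the
        # data seen by all processes, not just the local one.
        return self.correct.float() / self.total
```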
If it's not enough to store a set of state variables, you can try to make your metric gather all data from all the processes. Derive your own metric from the Metric base class, overriding the update() and compute() methods. Use add_state("data", default=[], dist_reduce_fx="cat") to create a list where you collect the data that you need for calculating the metric. dist_reduce_fx="cat" will cause the data from the different processes to be combined with torch.cat(). Internally it uses torch.distributed.all_gather. The tricky part is that it assumes that all processes create identically-sized tensors; if the sizes don't match, syncing will hang indefinitely.