I am an intermediate learner in PyTorch, and in some recent cases I have seen people use torch.inference_mode() instead of the famous torch.no_grad() while validating their trained agent in reinforcement learning (RL) experiments. I checked the documentation, and it has a table of the two flags that disable gradient computation. To be honest, the descriptions read exactly the same to me. Has someone figured out an explanation?

- https://pytorch.org/docs/1.12/notes/autograd.html#inference-mode – user2357112 Oct 25 '22 at 08:29
- That doesn't answer the question of why PyTorch would keep two different flags rather than combine them into one, or what the difference is in the PyTorch backend. It just explains the functionality of torch.inference_mode. – Satya Prakash Dash Nov 02 '22 at 03:35
1 Answer
So I have been scraping the web for a few days and I think I got my explanation: torch.inference_mode() has been added as an even more optimized way of doing inference with PyTorch (versus torch.no_grad()). I listened to the PyTorch podcast, and they give an explanation as to why there exist two different flags.
Version counting of tensors: Say you have some PyTorch code that you used to train an agent. When you run inference on the trained model under torch.no_grad(), some PyTorch machinery is still in play, such as the version counter of each tensor: a counter that gets allocated every time a tensor is created and is incremented (a version bump) whenever you mutate that tensor. Keeping track of the versions of all tensors costs extra computation, and under torch.no_grad() we can't just get rid of it, because autograd still has to watch out for tensor mutations, either (directly) on a specific tensor or (indirectly) through an alias, when that tensor is saved for the backward computation. (A short sketch of the version counter follows the next point.)

View tracking of tensors: PyTorch tensors are strided, which means PyTorch uses strides in the backend for indexing; strides let you directly access specific elements in the memory block. But what happens, as far as torch.autograd is concerned, if you take a tensor, create a new view of it, and mutate that view while the base tensor is tied to the backward computation? Under torch.no_grad(), PyTorch still keeps a record of view metadata so it can track which tensors require gradients and which do not, and this also adds extra overhead to your computation.
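You can actually watch the version counter under torch.no_grad(), and see inference tensors refuse the same mutation afterwards. A minimal sketch (note that _version is an internal, undocumented attribute, used here purely for illustration):

```python
import torch

# Version counting still runs under torch.no_grad():
with torch.no_grad():
    t = torch.ones(3)
    print(t._version)  # 0  (internal counter, for illustration only)
    t.add_(1)          # in-place mutation...
    print(t._version)  # 1  ...bumps the version counter anyway

# Tensors created under torch.inference_mode() skip version counting
# and view tracking entirely, so autograd refuses to let you mutate
# them once you are back outside the mode:
with torch.inference_mode():
    x = torch.ones(3)

try:
    x.add_(1)
except RuntimeError as e:
    print(e)  # complains about an in-place update to an inference tensor
```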
So torch.autograd checks for exactly these changes, and none of that tracking happens when you switch to torch.inference_mode() (instead of torch.no_grad()). If your code does not rely on the two mechanisms above, inference mode works and reduces the execution time. (The PyTorch dev team says they have seen a speedup of 5-10% while deploying models in production at Facebook.)
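If you want to check that claim on your own setup, a rough micro-benchmark sketch like the one below works (the model, sizes, and iteration count are made-up placeholders; the gap you measure will vary with hardware and model):

```python
import time
import torch
import torch.nn as nn

# Made-up model and batch, just to have something to time.
model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512))
model.eval()
x = torch.randn(256, 512)

def bench(ctx, n=200):
    with ctx():
        model(x)  # warm-up run
    start = time.perf_counter()
    with ctx():
        for _ in range(n):
            model(x)
    return time.perf_counter() - start

print(f"no_grad:        {bench(torch.no_grad):.3f} s")
print(f"inference_mode: {bench(torch.inference_mode):.3f} s")
```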
