I am reading some Python code that deals with probability distributions. In many places, instead of using a framework like PyTorch to compute the distribution directly, it uses the following code:
    def sample_energy_0(self, y, M):
        device = next(self.parameters()).device
        x = torch.randn(M, y.shape[0], self.latent_dim).to(device)
        return x

    def energy_0(self, x, y):
        return (x**2).sum(axis=2, keepdims=True) / 2
How does the call sum(axis=2, keepdims=True) work here?
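To make the question concrete, here is a minimal loop-based sketch of what I understand the call in energy_0 to compute, assuming x has shape (M, y.shape[0], latent_dim) as produced by sample_energy_0. The nested-list input and the helper name energy_0_lists are made up for illustration and are not part of the original code. Summing over axis 2 should collapse the last (latent) dimension, and keepdims=True should keep it as a dimension of size 1.

```python
# Hypothetical toy input with shape (M=2, batch=2, latent_dim=3),
# written as nested lists so that each axis is visible.
x = [
    [[1.0, 2.0, 3.0], [0.0, 1.0, 0.0]],
    [[2.0, 2.0, 2.0], [1.0, 1.0, 1.0]],
]


def energy_0_lists(x):
    """Loop-based stand-in for (x**2).sum(axis=2, keepdims=True) / 2."""
    return [
        # For each (sample, batch) pair, sum the squares over the last axis.
        # Wrapping the result in a one-element list mimics keepdims=True.
        [[sum(v ** 2 for v in row) / 2] for row in sample]
        for sample in x
    ]


energy = energy_0_lists(x)
print(energy)  # [[[7.0], [0.5]], [[6.0], [1.5]]]
```

If this reading is right, the result has shape (M, y.shape[0], 1), with one value of half the squared norm per latent vector.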