To save memory, I recommend using torch.einsum:
We can make use of the trigonometric identity

cos(x - y) = cos(x)*cos(y) + sin(x)*sin(y)

This lets us apply einsum: einsum's built-in summation performs the averaging sum over the last dimension, and the + between the two product terms is applied afterwards as a separate operation. In short:
xs, ys = torch.sin(x), torch.sin(y)
xc, yc = torch.cos(x), torch.cos(y)
# einsum computes the sin/cos products and the sum over the shared index k; + adds the two terms:
out = (torch.einsum('i k, j k -> i j', xs, ys) + torch.einsum('i k, j k -> i j', xc, yc)) / y.shape[1]
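The einsum contracts over the shared index k, so each output entry is out[i, j] = (sum_k sin(x[i,k])*sin(y[j,k]) + sum_k cos(x[i,k])*cos(y[j,k])) / B, which is exactly the mean of cos(x[i,k] - y[j,k]) over k. For two 2-D inputs this particular einsum is just a matrix product, so an equivalent formulation (a sketch using the same xs, ys, xc, yc as above) would be:

out = (xs @ ys.T + xc @ yc.T) / y.shape[1]  # same contraction, written as matmul

Either way, only the (batch, A) result is materialized, instead of the full (batch, A, B) intermediate that cos(x - y) creates.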
Since measuring memory consumption directly is a bit tedious, I resorted to measuring run time as a proxy. Below you can see your original method and my proposal for various input sizes. (The script for generating these plots is attached below.)

import time

import matplotlib.pyplot as plt
import torch


def main():
    ns = torch.logspace(1, 3.2, 20).to(torch.long)
    tns, tes = [], []
    for n in ns:
        tn, te = compare(n)
        tns.append(tn)
        tes.append(te)
    plt.loglog(ns, tns, ':.')
    plt.loglog(ns, tes, '.-')
    plt.loglog(ns, 1e-6 * ns**1, ':')  # O(n^1) reference line
    plt.loglog(ns, 1e-6 * ns**2, ':')  # O(n^2) reference line
    plt.legend(['naive', 'einsum', 'x^1', 'x^2'])
    plt.show()


def compare(n):
    batch = a = b = n
    x = torch.randn((batch, b))  # broadcast later to (batch, 1, B)
    y = torch.randn((a, b))      # broadcast later to (1, A, B)
    t = time.perf_counter()
    ra = af(x.unsqueeze(1), y.unsqueeze(0))
    print('naive method', tn := time.perf_counter() - t)
    t = time.perf_counter()
    rb = bf(x, y)
    print('einsum method', te := time.perf_counter() - t)
    print((ra - rb).abs().max())  # verify both methods give the same results
    return tn, te


def af(x, y):
    return torch.cos(x - y).mean(dim=2)


def bf(x, y):
    xs, ys = torch.sin(x), torch.sin(y)
    xc, yc = torch.cos(x), torch.cos(y)
    return (torch.einsum('i k, j k -> i j', xs, ys)
            + torch.einsum('i k, j k -> i j', xc, yc)) / y.shape[1]


main()
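If you do want to measure peak memory directly rather than using time as a proxy, one option is PyTorch's peak-memory counters. This is only a rough sketch, assuming you run on a CUDA device (on CPU you could fall back to something like tracemalloc), and it reuses af and bf from the script above:

def peak_memory(fn, *args):
    # Reset the allocator's high-water mark, run the function, and report peak bytes allocated.
    torch.cuda.reset_peak_memory_stats()
    fn(*args)
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated()

x = torch.randn(500, 500, device='cuda')
y = torch.randn(500, 500, device='cuda')
print('naive :', peak_memory(af, x.unsqueeze(1), y.unsqueeze(0)))
print('einsum:', peak_memory(bf, x, y))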