
torch.matmul in PyTorch supports broadcasting, which can consume too much memory. I am looking for an efficient implementation that avoids this excessive memory usage.

For example, the input tensors have the following sizes:

adj.size() == torch.Size([1, 3000, 3000])
s.size() == torch.Size([235, 3000, 10])
s.transpose(1, 2).size() == torch.Size([235, 10, 3000])

The task is to calculate:

link_loss = adj - torch.matmul(s, s.transpose(1, 2))
link_loss = torch.norm(link_loss, p=2)
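
For reference, here is a minimal self-contained version (the tensor values are random placeholders; only the shapes match my real data):

import torch

# Random placeholders; only the shapes match the actual data.
adj = torch.rand(1, 3000, 3000)   # dense adjacency, batch dim of 1
s = torch.rand(235, 3000, 10)     # assignment matrix per sample

# The matmul result alone has shape [235, 3000, 3000]; this is the
# allocation that fails on my machine.
link_loss = adj - torch.matmul(s, s.transpose(1, 2))
link_loss = torch.norm(link_loss, p=2)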

The original code is in the PyTorch extension package torch_geometric, in the definition of the function dense_diff_pool. torch.matmul(s, s.transpose(1, 2)) consumes excessive memory (my computer has only about 2 GB free), raising the error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
    torch.matmul(s, s.transpose(1, 2))
RuntimeError: $ Torch: not enough memory: you tried to allocate 7GB. Buy new RAM! at ..\aten\src\TH\THGeneral.cpp:201

The original code by the author of the package contains torch.matmul(s, s.transpose(1, 2)), whose result has size [235, 3000, 3000], which is larger than 7 GB.
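
That figure is consistent with a back-of-the-envelope check, assuming float32 (4 bytes per element):

# 235 batches of 3000 x 3000 float32 values
235 * 3000 * 3000 * 4   # = 8,460,000,000 bytes, roughly 7.9 GiB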

My attempt was to use a for loop:

batch_size = 235
link_loss = torch.sqrt(
    torch.stack(
        [torch.norm(adj - torch.matmul(s[i], s[i].transpose(0, 1)), p=2) ** 2
         for i in range(batch_size)]
    ).sum(dim=0)
)

This for loop is known to be slower than broadcasting or other PyTorch built-in functions. Question: is there any implementation faster than the [... for ...] loop? I am a newbie learning PyTorch. Thanks.
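
For comparison, here is a chunked variant I am considering (a sketch only; the function name and the chunk size are my own choices, not from torch_geometric). It processes the batch in chunks and accumulates the squared Frobenius norm, so at most chunk difference matrices of size [3000, 3000] are alive at once:

import torch

def link_loss_chunked(adj, s, chunk=8):
    # adj: [1, 3000, 3000], s: [235, 3000, 10]
    total = s.new_zeros(())
    for i in range(0, s.size(0), chunk):
        sc = s[i:i + chunk]                                 # [c, 3000, 10]
        diff = adj - torch.matmul(sc, sc.transpose(1, 2))   # [c, 3000, 3000]
        total = total + diff.pow(2).sum()                   # accumulate squared norm
    return total.sqrt()

With chunk=1 this degenerates to the per-sample loop above, and with chunk=235 it reproduces the full broadcast, so the chunk size trades speed against peak memory.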
