torch.matmul in PyTorch supports broadcasting, which can consume too much memory. I am looking for an efficient implementation that avoids this excessive memory usage.
For example, the input tensors have the following sizes:
adj.size() == torch.Size([1, 3000, 3000])
s.size() == torch.Size([235, 3000, 10])
s.transpose(1, 2).size() == torch.Size([235, 10, 3000])
The task is to calculate
link_loss = adj - torch.matmul(s, s.transpose(1, 2))
link_loss = torch.norm(link_loss, p=2)
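For reference, here is a scaled-down sketch with small dummy tensors (the shapes are illustrative, not the real data) showing where the large tensor comes from:

import torch

# scaled-down dummy tensors with the same layout as adj and s
adj_small = torch.rand(1, 30, 30)
s_small = torch.rand(5, 30, 10)

# the batched matmul alone already produces a full [5, 30, 30] tensor
# (with the real sizes this would be [235, 3000, 3000])
prod = torch.matmul(s_small, s_small.transpose(1, 2))
print(prod.size())                 # torch.Size([5, 30, 30])

# the subtraction then broadcasts adj_small's size-1 batch dim against it
diff = adj_small - prod
print(diff.size())                 # torch.Size([5, 30, 30])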
The original code comes from the PyTorch extension package torch_geometric, in the definition of the function dense_diff_pool.
torch.matmul(s, s.transpose(1, 2))
consumes excessive memory (my computer has only about 2GB of free memory), raising the error:
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
    torch.matmul(s, s.transpose(1, 2))
RuntimeError: $ Torch: not enough memory: you tried to allocate 7GB. Buy new RAM! at ..\aten\src\TH\THGeneral.cpp:201
In the original code by the author of the package, torch.matmul(s, s.transpose(1, 2)) has size [235, 3000, 3000], which takes more than 7GB.
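That figure matches a back-of-the-envelope calculation (assuming float32, i.e. 4 bytes per element):

# 235 x 3000 x 3000 elements, 4 bytes each for float32
num_elements = 235 * 3000 * 3000
bytes_needed = num_elements * 4
print(bytes_needed / 1024**3)      # ~7.88 GB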
My attempt was to use a for loop over the batch dimension instead:
batch_size = 235
link_loss = torch.sqrt(torch.stack([torch.norm(adj - torch.matmul(s[i], s[i].transpose(0, 1)), p=2)**2 for i in range(batch_size)]).sum(dim=0))
This kind of Python for loop is known to be slower than broadcasting or other PyTorch built-in operations (a self-contained version is below).
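For completeness, here is a self-contained version of that attempt with random dummy tensors of the sizes above (variable names and data are illustrative only):

import torch

batch_size = 235
adj = torch.rand(1, 3000, 3000)
s = torch.rand(batch_size, 3000, 10)

# accumulate the squared Frobenius norm one batch element at a time,
# so only one [3000, 3000] product is held in memory at any moment
squared_norms = torch.stack([
    torch.norm(adj - torch.matmul(s[i], s[i].transpose(0, 1)), p=2) ** 2
    for i in range(batch_size)
])
link_loss = torch.sqrt(squared_norms.sum(dim=0))
print(link_loss)

With the full sizes this keeps memory at roughly the size of adj plus a single [3000, 3000] product, but it pays for that with a slow Python-level loop.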
Question:
Is there any faster implementation, better than the [ ... for ... ] list comprehension above, that still avoids the memory blow-up?
I am a newbie at PyTorch. Thanks.