I have a PyTorch tensor T with shape (batch_size, window_size, filters, 3, 3) and I would like to pool it by trace. Specifically, I would like to obtain a tensor T_pooled of shape (batch_size, window_size//2, filters, 3, 3) by comparing the traces of paired frames. For example, if window_size=4, we would compare the traces of the 3x3 matrices T[i,0,k] and T[i,1,k] and select the one with the smaller trace to be T_pooled[i,0,k]. Similarly, comparing T[i,2,k] and T[i,3,k] gives T_pooled[i,1,k].
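To make the intended pairing concrete, here is a tiny hand-built illustration for a single (i, k). The toy tensor and its scale factors are made up for the example, and ties are assumed to keep the first frame of the pair:

```python
import torch

# Hypothetical toy input: batch_size=1, window_size=4, filters=1.
# Each frame is a scaled 3x3 identity, so the four traces are 9, 3, 6, 12.
scales = torch.tensor([3.0, 1.0, 2.0, 4.0])
T = scales.view(1, 4, 1, 1, 1) * torch.eye(3)  # shape (1, 4, 1, 3, 3)

i, k = 0, 0
# Pair (0, 1): traces 9 and 3, so frame 1 should become T_pooled[i, 0, k].
first = T[i, 0, k] if torch.trace(T[i, 0, k]) <= torch.trace(T[i, 1, k]) else T[i, 1, k]
# Pair (2, 3): traces 6 and 12, so frame 2 should become T_pooled[i, 1, k].
second = T[i, 2, k] if torch.trace(T[i, 2, k]) <= torch.trace(T[i, 3, k]) else T[i, 3, k]
```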
This can be done by looping over i and k, but that is very slow and inefficient. Is there a way to vectorize this pooling operation to speed it up?
Edit: Here is what I have tried so far. It uses a list comprehension inside two for loops, and it takes approximately 2.5 s to run on a tensor of shape (128, 120, 22, 3, 3).
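import torch

def TPL_Pairwise(x):
    # x has shape (batch_size, window_size, filters, 3, 3)
    x_pooled = torch.zeros(x.shape[0], x.shape[1] // 2, x.shape[2], x.shape[3], x.shape[4])
    # compute tensorized trace, shape (batch_size, window_size, filters)
    trace = torch.einsum('ijkll->ijkl', x).sum(-1)
    for i in range(x.shape[0]):
        for j in range(x.shape[2]):
            # for each pair of frames (k, k+1), keep the one with the smaller trace
            keep = [x[i, k, j] if trace[i, k, j] <= trace[i, k + 1, j] else x[i, k + 1, j]
                    for k in range(0, x.shape[1], 2)]
            x_pooled[i, :, j] = torch.stack(keep)
    return x_pooled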
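
One direction that looks like it should remove both loops is to reshape the window axis into pairs and select with torch.where. This is only a minimal, untimed sketch: the name tpl_pairwise_vectorized is made up, it assumes window_size is even (a trailing unpaired frame is not handled), and ties keep the first frame of the pair, matching the <= in the loop version.

```python
import torch

def tpl_pairwise_vectorized(x):
    # Hypothetical helper; assumes x has shape (batch, window, filters, n, n)
    # and that window is even.
    b, w, f, n, m = x.shape
    # Group consecutive frames into pairs: shape (b, w//2, 2, f, n, m).
    pairs = x.reshape(b, w // 2, 2, f, n, m)
    # Trace of every matrix: take the diagonal, then sum it -> (b, w//2, 2, f).
    trace = pairs.diagonal(dim1=-2, dim2=-1).sum(-1)
    # True where the first frame of the pair has the smaller (or equal) trace.
    keep_first = trace[:, :, 0] <= trace[:, :, 1]  # shape (b, w//2, f)
    # Broadcast the mask over the trailing matrix dimensions and select.
    return torch.where(keep_first[..., None, None], pairs[:, :, 0], pairs[:, :, 1])
```

Because the result is built from x itself, it should keep the dtype and device of the input, whereas a buffer created with torch.zeros defaults to float32 on the CPU.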