
Let's say I have a tensor of shape (1, 64, 128, 128) and I want to create a tensor of shape (1, 64, 255) holding the sums of all diagonals of every (128, 128) matrix (there is 1 main diagonal, 127 below it and 127 above it, so 255 in total). This is what I am currently doing:

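import torch

x = torch.rand(1, 64, 128, 128)

diag_sums = torch.zeros(1, 64, 255)
j = 0  # index of the single batch element
for k in range(-127, 128):
    # sum the diagonal at offset k of every (128, 128) matrix and store it at index k + 127
    diag_sums[j, :, k + 127] = torch.diagonal(x, offset=k, dim1=-2, dim2=-1).sum(dim=2)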
This is obviously very slow, since it uses a Python loop over all 255 offsets instead of computing them in parallel with respect to k.

I don't think this can be done using torch.diagonal since the function explicitly uses a single int for the offset parameter. If I could pass a list there, this would work, but I guess it would be complicated to implement (requiring changes in PyTorch itself).

I think it could be possible to implement this using torch.einsum, but I cannot think of a way to do it.
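The einsum idea can be made to work by first building a 0/1 mask with one (n, n) slice per diagonal and then contracting it against x. This is a sketch of that approach, not code from this thread; the helper name `diag_sums_einsum` is hypothetical:

```python
import torch

def diag_sums_einsum(x: torch.Tensor) -> torch.Tensor:
    """Sum every diagonal of each (n, n) matrix in x (shape (..., n, n)) with torch.einsum.

    Index d of the result holds the diagonal at torch.diagonal offset d - (n - 1),
    i.e. the same k + 127 order as the loop above when n = 128.
    """
    n = x.shape[-1]
    idx = torch.arange(n, device=x.device)
    offset = idx[None, :] - idx[:, None]  # offset[i, j] = j - i
    ks = torch.arange(-(n - 1), n, device=x.device)  # all 2n - 1 offsets, ascending
    # mask[d, i, j] = 1 exactly when element (i, j) lies on the diagonal with offset ks[d]
    mask = (offset[None] == ks[:, None, None]).to(x.dtype)
    # contract the last two dims of x against the mask: (..., n, n) x (2n-1, n, n) -> (..., 2n-1)
    return torch.einsum("...ij,dij->...d", x, mask)

x = torch.rand(1, 64, 128, 128)
out = diag_sums_einsum(x)
print(out.shape)  # torch.Size([1, 64, 255])
```

It trades memory for simplicity: the mask has (2n - 1)·n² entries, about 4.2 million for n = 128 (roughly 17 MB in float32), and the contraction does (2n - 1)·n² multiply-adds per matrix.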

So this is my question: how do I get the tensor described above?

Hristo Vrigazov

3 Answers


Have you considered using torch.nn.functional.conv2d?
You can sum the diagonals with a diagonal filter sliding across the tensor with appropriate zero padding.

import torch
import torch.nn.functional as nnf

# construct a diagonal filter using `eye` function, shape it appropriately
f = torch.eye(x.shape[2])[None, None,...].repeat(x.shape[1], 1, 1, 1)
# compute the diagonal sum with appropriate zero padding
conv_diag_sums = nnf.conv2d(x, f, padding=(x.shape[2]-1,0), groups=x.shape[1])[..., 0]

Note that the result is in the reverse order of the one you computed in the loop: offset k ends up at index 127 - k instead of k + 127. Here is the loop rewritten to match:

diag_sums = torch.zeros(1, 64, 255)
for k in range(-127, 128):
    # store offset k at index 127 - k, matching the conv2d output order
    diag_sums[0, :, 127 - k] = torch.diagonal(x, offset=k, dim1=-2, dim2=-1).sum(dim=2)

# compare (allclose rather than ==, since conv2d may add the terms in a different order)
torch.allclose(conv_diag_sums, diag_sums)

This returns True: they are the same.
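If you want the conv2d result in the same order as the question's loop (index k + 127), flipping the last dimension is enough. A sketch that wraps the code above in a function for a general (B, C, n, n) input; the name `diag_sums_conv` is hypothetical:

```python
import torch
import torch.nn.functional as nnf

def diag_sums_conv(x: torch.Tensor) -> torch.Tensor:
    """Sum every diagonal of each (n, n) matrix in a (B, C, n, n) tensor with conv2d.

    Index d of the result holds the sum of the diagonal at torch.diagonal offset d - (n - 1).
    """
    c, n = x.shape[1], x.shape[-1]
    # one (n, n) identity filter per channel; groups=c keeps the channels independent
    f = torch.eye(n, dtype=x.dtype, device=x.device)[None, None].repeat(c, 1, 1, 1)
    # n - 1 zero rows above and below give an output of shape (B, C, 2n - 1, 1)
    out = nnf.conv2d(x, f, padding=(n - 1, 0), groups=c)[..., 0]
    # conv2d puts offset (n - 1) - i at index i, so flip to get ascending offsets
    return out.flip(-1)

x = torch.rand(1, 64, 128, 128)
print(diag_sums_conv(x).shape)  # torch.Size([1, 64, 255])
```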

Shai

Shai's answer works, but it performs a lot of multiplications because of the large kernel size. I found a way to do this for my use case, based on this answer to a similar NumPy question: https://stackoverflow.com/a/35074207/6636290
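For the question's sizes (n = 128, 64 channels), a rough count of the dense work: each of the 2n - 1 conv2d outputs per channel is a dot product over the full n × n kernel, while a direct sum touches each element once, because every element lies on exactly one diagonal. These are operation counts only; actual run time depends on the backend.

```latex
\text{conv2d: } (2n-1)\,n^2 = 255 \cdot 16384 = 4\,177\,920 \text{ multiply-adds per matrix},
\qquad 64 \cdot 4\,177\,920 = 267\,386\,880 \text{ in total}
\\
\text{direct summation: } n^2 = 16\,384 \text{ additions per matrix}
```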

I am doing the following:

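import numpy as np
import torch
a = np.random.rand(128, 128).astype(np.float32)  # one (128, 128) matrix, as an example
# label element (i, j) with i + j (constant along anti-diagonals; use j - i + 127 for torch.diagonal offsets)
digitized = np.sum(np.indices(a.shape), axis=0).ravel()
digitized_tensor = torch.from_numpy(digitized)
a_tensor = torch.from_numpy(a)
torch.bincount(digitized_tensor, a_tensor.view(-1))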

If I could figure out a way to do this entirely in PyTorch (without NumPy's indices function), that would be great, but this answers the question.
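A pure-PyTorch version of the bincount idea is possible. This sketch (the name `diag_sums_bincount` is hypothetical) builds the bin labels with `torch.arange`, uses `j - i + (n - 1)` so the bins follow `torch.diagonal` offsets rather than anti-diagonals, and shifts each matrix's labels by 2n - 1 so that a single `torch.bincount` call covers the whole batch:

```python
import torch

def diag_sums_bincount(x: torch.Tensor) -> torch.Tensor:
    """Sum every diagonal of each (n, n) matrix in x (shape (..., n, n)) with torch.bincount.

    Index d of the result holds the sum of the diagonal at torch.diagonal offset d - (n - 1).
    """
    *batch, n, _ = x.shape
    m = x.numel() // (n * n)  # number of (n, n) matrices
    nbins = 2 * n - 1
    idx = torch.arange(n, device=x.device)
    # bin of element (i, j): its diagonal offset j - i, shifted into [0, 2n - 2]
    bins = idx[None, :] - idx[:, None] + (n - 1)
    # shift matrix t's bins by t * (2n - 1) so one bincount call covers the whole batch
    shift = torch.arange(m, device=x.device)[:, None, None] * nbins
    all_bins = (bins[None] + shift).reshape(-1)
    sums = torch.bincount(all_bins, weights=x.reshape(-1), minlength=m * nbins)
    return sums.reshape(*batch, nbins).to(x.dtype)

x = torch.rand(1, 64, 128, 128)
print(diag_sums_bincount(x).shape)  # torch.Size([1, 64, 255])
print(diag_sums_bincount(torch.arange(9.0).reshape(3, 3)).tolist())  # [6.0, 10.0, 12.0, 6.0, 2.0]
```

The label tensor has the same number of elements as x, so this costs one extra int64 tensor of that size.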

Hristo Vrigazov

The previous answers work, but there is another, faster solution that uses strides (and only PyTorch).

First I'll explain it with a single matrix, as it is easier to understand.

Given a matrix M of size (n, n), you can pad it with zeros on the left and right and then change the strides so that the resulting matrix has M's diagonals as columns. Then you just sum the columns to get your result.

import torch

def sum_all_diagonal_matrix(mat: torch.Tensor):
    n, _ = mat.shape
    zero_mat = torch.zeros((n, n))  # zero matrix used for padding

    mat_padded = torch.cat((zero_mat, mat, zero_mat), 1)  # pads the matrix on the left and right: shape (n, 3n)
    print(mat_padded)

    # each row of the view starts 3n + 1 elements after the previous one, i.e. one column further right,
    # so column j of the view holds the diagonal of mat at offset j - n
    mat_strided = mat_padded.as_strided((n, 2*n), (3*n + 1, 1))
    print(mat_strided)

    sum_diags = torch.sum(mat_strided, 0)  # sums the columns of the strided view
    return sum_diags[1:]  # column 0 (offset -n) is always zero, so drop it

X = torch.arange(9).reshape(3,3)
print(X)
# tensor([[0, 1, 2],
#        [3, 4, 5],
#        [6, 7, 8]]) 
print(sum_all_diagonal_matrix(X))
# tensor([ 6., 10., 12.,  6.,  2.])
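Why the columns are the diagonals, written out for the padded matrix P of shape (n, 3n), stored row-major, and the strided view V of shape (n, 2n) with strides (3n + 1, 1):

```latex
P_{r,c} = \begin{cases} M_{r,\,c-n} & n \le c < 2n \\ 0 & \text{otherwise} \end{cases}
\\
V_{i,j} = \operatorname{flat}(P)\big[(3n+1)\,i + j\big] = \operatorname{flat}(P)\big[3n\,i + (i+j)\big] = P_{i,\,i+j}
\qquad (i + j \le 3n - 2 < 3n)
\\
\sum_{i=0}^{n-1} V_{i,j} = \sum_{i\,:\,0 \le i+j-n < n} M_{i,\,i+(j-n)}
```

So column j is the diagonal at `torch.diagonal` offset k = j - n, for k from -n to n - 1. Column 0 (k = -n) never contains an element of M, which is why the function returns `sum_diags[1:]`. In the 3 × 3 example, column 1 has k = -2 and picks up only M[2, 0] = 6, the first value printed.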

You can do exactly the same with one more dimension:

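def sum_all_diagonal(mat: torch.Tensor):
    k, n, _ = mat.shape
    zero_mat = torch.zeros((k, n, n))
    mat_padded = torch.cat((zero_mat, mat, zero_mat), 2)  # shape (k, n, 3n)
    # same trick per matrix; consecutive matrices are 3n*n elements apart in memory
    mat_strided = mat_padded.as_strided((k, n, 2*n), (3*n*n, 3*n + 1, 1))
    sum_diags = torch.sum(mat_strided, 1)
    return sum_diags[:, 1:]  # drop column 0 (offset -n) to keep all 2n - 1 diagonals, as in the 2D version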
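The question's tensor has two leading dimensions, (1, 64). A sketch of the same stride trick for any (..., n, n) tensor, which flattens the leading dimensions into one and restores them at the end; the name `sum_all_diagonal_nd` is hypothetical. It drops only column 0, so it returns all 2n - 1 diagonals in the same order as the question's loop:

```python
import torch

def sum_all_diagonal_nd(x: torch.Tensor) -> torch.Tensor:
    """Stride trick for a tensor of shape (..., n, n).

    Index d of the result holds the sum of the diagonal at torch.diagonal offset d - (n - 1).
    """
    *batch, n, _ = x.shape
    mat = x.reshape(-1, n, n)  # flatten the leading dims: (k, n, n)
    k = mat.shape[0]
    zeros = torch.zeros((k, n, n), dtype=x.dtype, device=x.device)
    # torch.cat returns a new contiguous (k, n, 3n) tensor, so the strides below are valid
    padded = torch.cat((zeros, mat, zeros), 2)
    strided = padded.as_strided((k, n, 2 * n), (3 * n * n, 3 * n + 1, 1))
    # column j holds offset j - n; column 0 (offset -n) is always empty, so drop it
    sums = strided.sum(1)[:, 1:]
    return sums.reshape(*batch, 2 * n - 1)

x = torch.rand(1, 64, 128, 128)
print(sum_all_diagonal_nd(x).shape)  # torch.Size([1, 64, 255])
```

Creating the zero padding with the input's dtype and device keeps `torch.cat` from mixing dtypes and lets the same code run on a GPU tensor.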
Etienne__