Is there a simple way to zero the diagonal of a PyTorch tensor?

For example, I have:

tensor([[2.7183, 0.4005, 2.7183, 0.5236],
        [0.4005, 2.7183, 0.4004, 1.3469],
        [2.7183, 0.4004, 2.7183, 0.5239],
        [0.5236, 1.3469, 0.5239, 2.7183]])

And I want to get:

tensor([[0.0000, 0.4005, 2.7183, 0.5236],
        [0.4005, 0.0000, 0.4004, 1.3469],
        [2.7183, 0.4004, 0.0000, 0.5239],
        [0.5236, 1.3469, 0.5239, 0.0000]])
iacob
gus
Comment: https://stackoverflow.com/questions/49512313/masking-diagonal-to-a-specific-value-with-pytorch-tensors/66760701#66760701 – iacob Mar 23 '21 at 10:46

5 Answers

I believe the simplest way is to use torch.diagonal:

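import torch

z = torch.randn(4, 4)
torch.diagonal(z, 0).zero_()  # diagonal() returns a view, so zero_() modifies z in place
print(z)
# example output (torch.randn values differ on every run):
# tensor([[ 0.0000, -0.6211,  0.1120,  0.8362],
#         [-0.1043,  0.0000,  0.1770,  0.4197],
#         [ 0.7211,  0.1138,  0.0000, -0.7486],
#         [-0.5434, -0.8265, -0.2436,  0.0000]])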

This way the code is explicit, and you delegate the work to PyTorch's built-in functions.
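Since torch.diagonal returns a view, the same idea should extend to a batch of matrices by choosing which two dimensions form the diagonal through dim1/dim2. A minimal sketch, assuming a tensor of shape (batch, n, n); the helper name zero_diagonal_ is hypothetical:

```python
import torch

def zero_diagonal_(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: zero the main diagonal of every matrix in x, in place.

    Assumes the last two dimensions of x are the matrix dimensions.
    """
    # dim1/dim2 select the matrix dimensions; the result is a view into x,
    # so zero_() writes straight through to x.
    torch.diagonal(x, offset=0, dim1=-2, dim2=-1).zero_()
    return x

batch = torch.ones(3, 4, 4)
zero_diagonal_(batch)
print(batch.diagonal(dim1=-2, dim2=-1))  # one row of zeros per matrix
```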

trialNerror

You can simply use:

x.fill_diagonal_(0)
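Tensor.fill_diagonal_ modifies x in place and returns it. A short sketch applying it to the matrix from the question; calling it on a clone() is one way to get a zeroed copy while keeping the original intact:

```python
import torch

x = torch.tensor([[2.7183, 0.4005, 2.7183, 0.5236],
                  [0.4005, 2.7183, 0.4004, 1.3469],
                  [2.7183, 0.4004, 2.7183, 0.5239],
                  [0.5236, 1.3469, 0.5239, 2.7183]])

# Out of place: fill the diagonal of a copy, leaving x unchanged.
y = x.clone().fill_diagonal_(0)
print(torch.diagonal(x))  # x still has 2.7183 on its diagonal

# In place: overwrite x's own diagonal.
x.fill_diagonal_(0)
```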
iacob

Yes, there are a couple of ways to do that. The simplest is to index the diagonal directly:

import torch

tensor = torch.rand(4, 4)
tensor[torch.arange(tensor.shape[0]), torch.arange(tensor.shape[1])] = 0

This assigns 0 to every index pair (0, 0), (1, 1), ..., (n-1, n-1) of a square n×n tensor.
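Indexing with torch.arange(tensor.shape[0]) and torch.arange(tensor.shape[1]) relies on the tensor being square: with a non-square shape the two index tensors have different lengths and generally cannot be broadcast together. A sketch of one workaround, assuming a 2-D tensor, that sizes the index by min(shape):

```python
import torch

tensor = torch.rand(3, 5)

# The main diagonal of an m x n matrix has min(m, n) entries.
idx = torch.arange(min(tensor.shape))

# Advanced indexing pairs idx[i] with idx[i], i.e. positions (i, i).
tensor[idx, idx] = 0
```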

Another way (readability is debatable) is to multiply by the logical complement of an identity mask built with torch.eye:

tensor = torch.rand(4, 4)
tensor *= ~(torch.eye(*tensor.shape).bool())

This creates an additional n×n mask and multiplies every element of the tensor, so it does considerably more work; I'd stick with the first version.
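Multiplying by the mask also turns an inf or NaN diagonal entry into NaN rather than 0, because inf * 0 is NaN. If you want to keep the mask approach, Tensor.masked_fill_ writes the value only where the mask is True instead of multiplying. A sketch, with an inf placed on the diagonal purely for illustration:

```python
import torch

tensor = torch.rand(4, 4)
tensor[1, 1] = float("inf")  # illustration: multiplying this by 0 would give NaN

# Boolean mask that is True only on the main diagonal.
mask = torch.eye(*tensor.shape, device=tensor.device).bool()

# masked_fill_ overwrites the masked entries in place instead of multiplying,
# so the inf on the diagonal becomes 0 rather than NaN.
tensor.masked_fill_(mask, 0)
```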

Szymon Maszke

As an alternative to indexing with two separate tensors, you could combine Tensor.repeat and Tensor.split, taking advantage of the fact that split returns a tuple, which indexing then treats as one index tensor per dimension:

>>> x[torch.arange(len(x)).repeat(2).split(len(x))] = 0
>>> x
tensor([[0.0000, 0.4005, 2.7183, 0.5236],
        [0.4005, 0.0000, 0.4004, 1.3469],
        [2.7183, 0.4004, 0.0000, 0.5239],
        [0.5236, 1.3469, 0.5239, 0.0000]])
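To see why this works, it helps to look at the intermediate index tuple. A sketch, assuming a square 2-D tensor x:

```python
import torch

x = torch.rand(4, 4)
n = len(x)  # number of rows

# torch.arange(n).repeat(2) -> tensor([0, 1, 2, 3, 0, 1, 2, 3])
# .split(n)                 -> (tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3]))
index = torch.arange(n).repeat(2).split(n)

# A tuple of index tensors supplies one tensor per dimension, so this is
# equivalent to x[rows, cols] with rows == cols == torch.arange(n).
x[index] = 0
```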
Ivan

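Here's another way, using a strided slice of the flattened tensor. In a contiguous n×n tensor, element (i, i) sits at flat position i*(n + 1), so stepping by n + 1 visits exactly the diagonal:

x.flatten()[::(x.shape[-1]+1)] = 0

This relies on x being contiguous: flatten() then returns a view, so the assignment writes into x. If x is not contiguous (e.g. a transpose), flatten() may return a copy instead, and x is then silently left unchanged.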
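Because flatten() can fall back to a copy, the strided assignment may silently do nothing on some inputs. One way to make that failure loud is Tensor.view(-1), which raises a RuntimeError when a view is impossible rather than copying. A sketch, assuming x is a square 2-D tensor:

```python
import torch

x = torch.rand(4, 4)
n = x.shape[-1]

# On a contiguous tensor, view(-1) shares storage with x,
# so writing every (n + 1)-th element zeroes the diagonal.
x.view(-1)[:: n + 1] = 0

# On a non-contiguous tensor (here a transpose), view(-1) raises
# instead of returning a copy.
t = torch.rand(4, 4).t()
try:
    t.view(-1)
    raised = False
except RuntimeError:
    raised = True
print(raised)  # True
```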
iacob
Shai