
For example, given the PyTorch matrix A:

A = tensor([[3,2,1],[1,0,2],[2,2,0]])

I need to replace the zeros on the diagonal with 1, so the result should be:

tensor([[3,2,1],[1,1,2],[2,2,1]])
iacob
Sirui Li
Related: https://stackoverflow.com/q/49512313 https://stackoverflow.com/q/49429147 https://stackoverflow.com/q/65712349 – iacob Mar 23 '21 at 14:02

2 Answers


You can use PyTorch's built-in diagonal functions to replace the diagonal elements like so:

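import torch

A = torch.tensor([[3, 2, 1], [1, 0, 2], [2, 2, 0]])
mask = A.diagonal() == 0           # True where a diagonal entry is 0
A += torch.diag(mask.to(A.dtype))  # add 1 only at those diagonal positions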
>>> A
tensor([[3, 2, 1],
        [1, 1, 2],
        [2, 2, 1]])

If you want to replace 0s with a different value, use mask * replace_value instead of mask; since the masked entries are 0, adding replace_value sets them to exactly that value.
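A minimal sketch of the mask * replace_value variant described above, assuming A is a square matrix (torch.diag builds a square matrix from the diagonal vector, so a non-square A would not broadcast). The helper name replace_zero_diagonal is hypothetical, and unlike the += version it returns a new tensor instead of modifying A in place.

```python
import torch

# Hypothetical helper name; wraps the mask-and-add approach from the answer.
def replace_zero_diagonal(A: torch.Tensor, replace_value) -> torch.Tensor:
    """Return a copy of square matrix A with zero diagonal entries set to replace_value."""
    mask = A.diagonal() == 0  # bool vector, True where the diagonal entry is 0
    # The masked entries are 0, so adding replace_value there sets them to replace_value.
    return A + torch.diag(mask.to(A.dtype) * replace_value)

A = torch.tensor([[3, 2, 1], [1, 0, 2], [2, 2, 0]])
print(replace_zero_diagonal(A, 7))
# tensor([[3, 2, 1],
#         [1, 7, 2],
#         [2, 2, 7]])
```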

iacob

You can use advanced (integer-array) indexing to extract the diagonal, modify it, and then write it back into the original matrix:

import torch

N = 10
a = torch.randint(0, N, (N, N))

# Example output (random, so yours will differ):
# tensor([[0, 9, 6, 6, 9, 9, 3, 1, 8, 4],
#         [8, 1, 6, 8, 5, 8, 7, 8, 1, 4],
#         [1, 9, 8, 4, 7, 0, 2, 9, 6, 2],
#         [9, 5, 9, 6, 7, 1, 4, 0, 2, 6],
#         [1, 2, 8, 0, 9, 0, 4, 3, 9, 9],
#         [1, 4, 6, 9, 6, 5, 1, 2, 0, 7],
#         [4, 8, 1, 3, 1, 6, 1, 3, 5, 6],
#         [3, 8, 9, 9, 1, 3, 0, 9, 6, 6],
#         [7, 4, 3, 0, 3, 5, 6, 6, 9, 2],
#         [3, 1, 0, 8, 3, 5, 6, 6, 5, 5]])

idx = torch.arange(N)
diag = a[idx, idx]   # copy of entries (0,0), (1,1), ..., (N-1,N-1)
diag[diag == 0] = 1  # set according to your condition
a[idx, idx] = diag   # write the modified diagonal back into a

# tensor([[1, 9, 6, 6, 9, 9, 3, 1, 8, 4],
#         [8, 1, 6, 8, 5, 8, 7, 8, 1, 4],
#         [1, 9, 8, 4, 7, 0, 2, 9, 6, 2],
#         [9, 5, 9, 6, 7, 1, 4, 0, 2, 6],
#         [1, 2, 8, 0, 9, 0, 4, 3, 9, 9],
#         [1, 4, 6, 9, 6, 5, 1, 2, 0, 7],
#         [4, 8, 1, 3, 1, 6, 1, 3, 5, 6],
#         [3, 8, 9, 9, 1, 3, 0, 9, 6, 6],
#         [7, 4, 3, 0, 3, 5, 6, 6, 9, 2],
#         [3, 1, 0, 8, 3, 5, 6, 6, 5, 5]])
jhso
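A possible variant of the extract-modify-write-back approach (a sketch, not taken from either answer): indexing with a[idx, idx] produces a copy, which is why the write-back step is needed, whereas Tensor.diagonal() returns a view that shares storage with the matrix. Modifying that view in place, here with masked_fill_, updates the matrix directly, assuming you are fine with mutating it.

```python
import torch

A = torch.tensor([[3, 2, 1], [1, 0, 2], [2, 2, 0]])

# Tensor.diagonal() returns a view sharing storage with A,
# so an in-place fill on it changes A itself; no write-back is needed.
diag = A.diagonal()
diag.masked_fill_(diag == 0, 1)
print(A)
# tensor([[3, 2, 1],
#         [1, 1, 2],
#         [2, 2, 1]])
```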