For simplicity, let's say you have a matrix L_1
and want to replace its diagonal with zeros. You can do this in multiple ways.
Using fill_diagonal_():
import torch

L_1 = torch.tensor([[1,2,3],[4,5,6],[7,8,9]], dtype=torch.float32)
L_1 = L_1.fill_diagonal_(0.)
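As a quick sanity check on the 3×3 example above: fill_diagonal_() modifies the tensor in place and also returns it, so the reassignment is optional.

```python
import torch

L_1 = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float32)
out = L_1.fill_diagonal_(0.)  # in-place; returns the same tensor object
print(out)
# tensor([[0., 2., 3.],
#         [4., 0., 6.],
#         [7., 8., 0.]])
```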
Using advanced indexing:
L_1 = torch.tensor([[1,2,3],[4,5,6],[7,8,9]], dtype=torch.float32)
length = len(L_1)
zero_vector = torch.zeros(length, dtype=torch.float32)
L_1[range(length), range(length)] = zero_vector
Using scatter_():
L_1 = torch.tensor([[1,2,3],[4,5,6],[7,8,9]], dtype=torch.float32)
diag_idx = torch.arange(len(L_1)).unsqueeze(1)
zero_matrix = torch.zeros(L_1.shape, dtype=torch.float32)
L_1 = L_1.scatter_(1, diag_idx, zero_matrix)
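Note that scatter_() also has an overload taking a scalar value instead of a source tensor, so the zero_matrix can be dropped entirely; a minimal sketch:

```python
import torch

L_1 = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float32)
diag_idx = torch.arange(len(L_1)).unsqueeze(1)  # shape (3, 1): one column index per row
L_1 = L_1.scatter_(1, diag_idx, 0.0)  # scalar-value overload of scatter_
```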
Note that all of the above solutions are in-place operations, which can affect the backward pass, because the original values might be needed to compute the gradients. If you instead want to keep the backward pass unaffected, i.e. "break the graph" by not recording the change, so that no gradient corresponding to this operation is computed in the backward pass, you can simply add .data when using advanced indexing or scatter_().
Using advanced indexing with .data:
L_1 = torch.tensor([[1,2,3],[4,5,6],[7,8,9]], dtype=torch.float32)
length = len(L_1)
zero_vector = torch.zeros(length, dtype=torch.float32)
L_1[range(length), range(length)] = zero_vector.data
Using scatter_() with .data:
L_1 = torch.tensor([[1,2,3],[4,5,6],[7,8,9]], dtype=torch.float32)
diag_idx = torch.arange(len(L_1)).unsqueeze(1)
zero_matrix = torch.zeros(L_1.shape, dtype=torch.float32)
L_1 = L_1.scatter_(1, diag_idx, zero_matrix.data)
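If you prefer an out-of-place alternative that leaves the original tensor untouched and lets autograd flow through the off-diagonal entries, multiplying by a mask built from torch.eye() is a simple sketch:

```python
import torch

L_1 = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float32)
mask = 1.0 - torch.eye(len(L_1))  # 0 on the diagonal, 1 everywhere else
result = L_1 * mask               # out-of-place: L_1 itself is unchanged
```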
For reference, check out this discussion.