For example, I have a tensor:

import tensorflow.compat.v1 as tf
import numpy as np
a = tf.constant(np.array([[1,2,3,4,5],
                          [2,2,4,5,6],
                          [3,4,3,6,7],
                          [4,5,6,4,8],
                          [5,6,7,8,5]))

It's symmetric. Now I only want to keep the entries where abs(i - j) > s, where i and j denote the row and column indices and s is a parameter.

Because of the symmetry, this amounts to j - i > s.

So if I set s = 2, I want to convert a to:

        tf.constant(np.array([[0,0,0,4,5],
                              [0,0,0,0,6],
                              [0,0,0,0,0],
                              [0,0,0,0,0],
                              [0,0,0,0,0]))

Is there any convenient way to do this in TF 1.x? Thanks!

Quail Wwk
  • Does this answer your question? [Extract upper or lower triangular part of a numpy matrix](https://stackoverflow.com/questions/8905501/extract-upper-or-lower-triangular-part-of-a-numpy-matrix) – Joe Jun 18 '20 at 17:42
  • This shows how to set the diagonal to zero: https://stackoverflow.com/questions/53378148/how-to-find-the-sum-of-elements-above-and-below-the-diagonal-of-a-matrix-in-pyth. See also https://numpy.org/doc/stable/reference/generated/numpy.tril_indices.html – Joe Jun 18 '20 at 17:42

1 Answer

You can do that like this:

import tensorflow.compat.v1 as tf
import numpy as np

a = tf.constant(np.array([[1, 2, 3, 4, 5],
                          [2, 2, 4, 5, 6],
                          [3, 4, 3, 6, 7],
                          [4, 5, 6, 4, 8],
                          [5, 6, 7, 8, 5]]))
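s = 2
shape = tf.shape(a)
# Row (i) and column (j) index grids with the same shape as a
i, j = tf.meshgrid(tf.range(shape[0]), tf.range(shape[1]), indexing='ij')
# True where the entry lies more than s diagonals away from the main diagonal
mask = tf.math.abs(i - j) > s
# Zero out everything outside the mask
result = a * tf.dtypes.cast(mask, a.dtype)
tf.print(result)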
# [[0 0 0 4 5]
#  [0 0 0 0 6]
#  [0 0 0 0 0]
#  [4 0 0 0 0]
#  [5 6 0 0 0]]

The result is different from what you show, but it is what corresponds to the formula abs(i - j) > s. If you only want the upper part, do instead:

mask = j - i > s
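
To check the two masks outside of a TensorFlow graph, here is a minimal NumPy sketch, assuming the matrix is available as a plain array (as it is in the question, where it is built with np.array). It mirrors the index-grid approach above and compares the upper-only result with np.triu, along the lines of the NumPy question linked in the comments. The names both and upper are only illustrative.

```python
import numpy as np

a = np.array([[1, 2, 3, 4, 5],
              [2, 2, 4, 5, 6],
              [3, 4, 3, 6, 7],
              [4, 5, 6, 4, 8],
              [5, 6, 7, 8, 5]])
s = 2

# Row and column index grids, the NumPy counterpart of
# tf.meshgrid(..., indexing='ij') in the TensorFlow code.
i, j = np.indices(a.shape)

# Two-sided mask: keeps entries on both sides of the diagonal.
both = a * (np.abs(i - j) > s)

# One-sided mask: keeps only entries above the diagonal.
upper = a * (j - i > s)

# np.triu with k = s + 1 keeps entries with j - i >= s + 1,
# which is the same condition as j - i > s for integer indices.
assert np.array_equal(upper, np.triu(a, k=s + 1))

print(both)
print(upper)
```

The two-sided mask is itself symmetric, which is why abs(i - j) > s also keeps the lower-left corner; only j - i > s reproduces the matrix shown in the question.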
jdehesa