When working with the tri... set of functions, it can be useful to examine the source code. They are all Python, and are based on np.tri.
Make a small sample array - to illustrate and verify the answer:
In [205]: arr = np.arange(18).reshape(2,3,3) # arange(1,19) might be better
In [206]: arr
Out[206]:
array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8]],

       [[ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17]]])
np.tril sets the upper triangle values to 0. It works in this case, applying to the last 2 dimensions, though older versions of the documentation did not mention 3d arrays.
In [207]: np.tril(arr)
Out[207]:
array([[[ 0,  0,  0],
        [ 3,  4,  0],
        [ 6,  7,  8]],

       [[ 9,  0,  0],
        [12, 13,  0],
        [15, 16, 17]]])
In the code, it first constructs a boolean mask from the last 2 dimensions:
In [208]: mask = np.tri(*arr.shape[-2:], dtype=bool)
In [209]: mask
Out[209]:
array([[ True, False, False],
       [ True,  True, False],
       [ True,  True,  True]])
and uses np.where to set some values to 0. This works in the 3d case by broadcasting: mask and arr match on the last 2 dimensions, so mask can broadcast to match:
In [210]: np.where(mask, arr, 0)
Out[210]:
array([[[ 0,  0,  0],
        [ 3,  4,  0],
        [ 6,  7,  8]],

       [[ 9,  0,  0],
        [12, 13,  0],
        [15, 16, 17]]])
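The mask-plus-where steps above can be collected into one small function. This is only a sketch to check the idea against np.tril, not a copy of the NumPy source; the name tril_sketch is made up for this illustration.

```python
import numpy as np

def tril_sketch(a, k=0):
    """Hypothetical helper: zero everything above the k-th diagonal
    of the last two axes, using a 2d mask and broadcasting."""
    a = np.asarray(a)
    # 2d boolean mask built from the last two dimensions only
    mask = np.tri(*a.shape[-2:], k=k, dtype=bool)
    # mask has shape (rows, cols); it broadcasts over any leading axes
    return np.where(mask, a, 0)

arr = np.arange(18).reshape(2, 3, 3)
result = tril_sketch(arr)

# The broadcast that np.where performs implicitly, made explicit
mask = np.tri(3, dtype=bool)
full_mask = np.broadcast_to(mask, arr.shape)

print(np.array_equal(result, np.tril(arr)))
print(full_mask.shape)
```

The explicit np.broadcast_to call is not needed in practice; it is there only to show that the (3, 3) mask is stretched to (2, 3, 3) without copying.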
Your tril_indices is just the indices of this mask:
In [217]: np.nonzero(mask) # aka np.where
Out[217]: (array([0, 1, 1, 2, 2, 2]), array([0, 0, 1, 0, 1, 2]))
In [218]: np.tril_indices(3)
Out[218]: (array([0, 1, 1, 2, 2, 2]), array([0, 0, 1, 0, 1, 2]))
They can't be used directly to index arr:
In [220]: arr[np.tril_indices(3)].shape
Traceback (most recent call last):
File "<ipython-input-220-e26dc1f514cc>", line 1, in <module>
arr[np.tril_indices(3)].shape
IndexError: index 2 is out of bounds for axis 0 with size 2
In [221]: arr[:,np.tril_indices(3)].shape
Out[221]: (2, 2, 6, 3)
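Both results follow from how the 2-tuple is interpreted. In the first case the two arrays index axes 0 and 1, and axis 0 only has size 2. In the second, the tuple sits in a single index slot, so it is treated like one (2, 6) integer array applied to axis 1. A small check of that reading (variable names here are just for illustration):

```python
import numpy as np

arr = np.arange(18).reshape(2, 3, 3)
idx = np.tril_indices(3)          # tuple of two length-6 arrays

# Stacking the tuple gives a single (2, 6) integer array
as_array = np.asarray(idx)

# Indexing axis 1 with that array replaces the size-3 axis by (2, 6)
shape_from_tuple = arr[:, idx].shape
shape_from_array = arr[:, as_array].shape

print(as_array.shape, shape_from_tuple, shape_from_array)
```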
But unpacking the two indexing arrays, and applying them to the last 2 dimensions, works:
In [222]: I,J = np.tril_indices(3)
In [223]: I,J
Out[223]: (array([0, 1, 1, 2, 2, 2]), array([0, 0, 1, 0, 1, 2]))
In [224]: arr[:,I,J]
Out[224]:
array([[ 0,  3,  4,  6,  7,  8],
       [ 9, 12, 13, 15, 16, 17]])
The boolean mask can also be used directly:
In [226]: arr[:,mask]
Out[226]:
array([[ 0,  3,  4,  6,  7,  8],
       [ 9, 12, 13, 15, 16, 17]])
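The same index arrays also work for assignment, so the packed (2, 6) values can be written back into a zero array. The sketch below uses ... instead of : so that it would also apply with more leading dimensions; that generalization is my addition, not something the example above needs.

```python
import numpy as np

arr = np.arange(18).reshape(2, 3, 3)
I, J = np.tril_indices(3)
mask = np.tri(3, dtype=bool)

# Two equivalent ways to pull out the lower-triangle values
packed = arr[..., I, J]        # integer index arrays
packed_mask = arr[..., mask]   # boolean mask on the last two axes

# Scatter the packed values back into a zero array
restored = np.zeros_like(arr)
restored[..., I, J] = packed

print(packed.shape)
print(np.array_equal(packed, packed_mask))
print(np.array_equal(restored, np.tril(arr)))
```

Both forms return the values in row-major order of the True positions, which is why they agree element for element.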
The base np.tri works by simply doing an outer >= on indices:
In [231]: m = np.greater_equal.outer(np.arange(3),np.arange(3))
In [232]: m
Out[232]:
array([[ True, False, False],
       [ True,  True, False],
       [ True,  True,  True]])
The same comparison written with explicit broadcasting:
In [234]: np.arange(3)[:,None]>=np.arange(3)
Out[234]:
array([[ True, False, False],
       [ True,  True, False],
       [ True,  True,  True]])
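The outer comparison extends naturally to non-square shapes and a diagonal offset. The function below is a sketch written from that idea and checked against np.tri, not taken from its source; tri_sketch is a hypothetical name, and shifting the column indices by k is my own way of expressing the offset.

```python
import numpy as np

def tri_sketch(n, m=None, k=0):
    """Hypothetical helper: boolean (n, m) array that is True
    on and below the k-th diagonal."""
    if m is None:
        m = n
    rows = np.arange(n)[:, None]     # column vector of row indices
    cols = np.arange(-k, m - k)      # column indices shifted by k
    return rows >= cols              # broadcasts to shape (n, m)

# Compare with np.tri over a few shapes and offsets
checks = [
    np.array_equal(tri_sketch(n, m, k), np.tri(n, m, k, dtype=bool))
    for n in (1, 3, 4)
    for m in (2, 3, 5)
    for k in (-2, -1, 0, 1, 2)
]
print(all(checks))
```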