Assume a PyTorch tensor of shape (2, X), i.e. it always has exactly 2 rows:
A = tensor([[ 1.,  2.,  2.,  3.,  3.,  3.,  4.,  4.,  4.],
            [43., 33., 43., 76., 33., 76., 55., 55., 55.]])
torch.unique(A, dim=1)
will return:
tensor([[ 1.,  2.,  2.,  3.,  3.,  4.],
        [43., 33., 43., 33., 76., 55.]])
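torch.unique can also report, for every original column, which unique column it maps to (return_inverse=True) and how often each unique column occurs (return_counts=True). The sketch below only prints that extra information for this A; it is a starting point for getting the indices, not the answer itself:

```python
import torch

A = torch.tensor([[ 1.,  2.,  2.,  3.,  3.,  3.,  4.,  4.,  4.],
                  [43., 33., 43., 76., 33., 76., 55., 55., 55.]])

# With dim=1 the unique columns are sorted lexicographically.
uniq, inverse, counts = torch.unique(
    A, dim=1, return_inverse=True, return_counts=True
)
print(uniq)
print(inverse)  # tensor([0, 1, 2, 4, 3, 4, 5, 5, 5]): unique-column id of each original column
print(counts)   # tensor([1, 1, 1, 1, 2, 3]): occurrences of each unique column
```

Note that inverse[3] and inverse[5] are both 4: columns 3 and 5 are the same (3, 76) column, even though they are not adjacent.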
But I also need, for every unique column, the index at which it first appears in the original input. In this case, the indices should be:
tensor([0, 1, 2, 3, 4, 6])
# Explanation
# A = tensor([[ 1.,  2.,  2.,  3.,  3.,  3.,  4.,  4.,  4.],
#             [43., 33., 43., 76., 33., 76., 55., 55., 55.]])
#              (0)  (1)  (2)  (3)  (4)       (6)
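One possible approach, sketched under the assumption of a PyTorch version whose torch.sort accepts stable=True (1.9 or later): stable-sort the inverse mapping so that the columns belonging to each unique column stay in their original order, then pick the first entry of each group using the cumulative counts as group offsets. The helper name first_occurrence_indices is hypothetical.

```python
import torch

def first_occurrence_indices(A: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: for each unique column of A (in torch.unique's
    sorted order), return the index of the column where it first appears."""
    _, inverse, counts = torch.unique(
        A, dim=1, sorted=True, return_inverse=True, return_counts=True
    )
    # A stable sort keeps equal ids in original column order, so the first
    # entry of each group in `order` is that group's earliest column.
    _, order = torch.sort(inverse, stable=True)
    # Offset of each group inside `order`: 0, c0, c0 + c1, ...
    starts = torch.cat((counts.new_zeros(1), counts.cumsum(0)[:-1]))
    return order[starts]

A = torch.tensor([[ 1.,  2.,  2.,  3.,  3.,  3.,  4.,  4.,  4.],
                  [43., 33., 43., 76., 33., 76., 55., 55., 55.]])

first = first_occurrence_indices(A)
print(first)                # tensor([0, 1, 2, 4, 3, 6]), aligned with torch.unique's columns
print(first.sort().values)  # tensor([0, 1, 2, 3, 4, 6]), in order of first appearance
```

The unsorted result lines up column by column with the output of torch.unique; sorting it gives the list of first-appearance positions asked for above.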
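This is tricky for me because the second row of A is not necessarily sorted within a group of equal first-row values, so identical columns are not always adjacent:
A = tensor([[ 1.,  2.,  2.,  3.,  3.,  3.,  4.,  4.,  4.],
            [43., 33., 43., 76., 33., 76., 55., 55., 55.]])
For example, the column (3, 76) appears at indices 3 and 5, with (3, 33) in between at index 4.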
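Because duplicate columns need not be adjacent, an alternative that does not depend on any ordering is to take, for each unique column, the minimum original position that maps to it. A minimal sketch using Tensor.scatter_reduce with reduce="amin", assuming PyTorch 1.12 or later (where scatter_reduce takes dim, index, src, reduce):

```python
import torch

A = torch.tensor([[ 1.,  2.,  2.,  3.,  3.,  3.,  4.,  4.,  4.],
                  [43., 33., 43., 76., 33., 76., 55., 55., 55.]])

n = A.shape[1]
uniq, inverse = torch.unique(A, dim=1, return_inverse=True)
# Original column positions 0..n-1 (int64, same dtype as `inverse`).
positions = torch.arange(n, device=A.device)
# Start every slot at n (larger than any valid index), then keep the
# smallest position that maps to each unique column.
first = torch.full((uniq.shape[1],), n, dtype=torch.long, device=A.device)
first = first.scatter_reduce(0, inverse, positions, reduce="amin")
print(first)                # tensor([0, 1, 2, 4, 3, 6])
print(first.sort().values)  # tensor([0, 1, 2, 3, 4, 6])
```

This avoids the explicit sort and offset bookkeeping: the reduction runs in one pass over the columns, and the initial value n is always overwritten because every unique column has at least one occurrence.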
Is there a simple and efficient method to get the desired indices?
P.S. In case it helps: the first row of the tensor is always sorted in ascending order.
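For small X, a brute-force check works without torch.unique and without relying on the sorted first row: a column is a first occurrence when no earlier column is identical to it. This builds an X×X comparison matrix, so memory grows quadratically with X; treat it as a sketch for short tensors only:

```python
import torch

A = torch.tensor([[ 1.,  2.,  2.,  3.,  3.,  3.,  4.,  4.,  4.],
                  [43., 33., 43., 76., 33., 76., 55., 55., 55.]])

n = A.shape[1]
# same[i, j] is True when columns i and j match in every row.
same = (A[:, :, None] == A[:, None, :]).all(dim=0)
# earlier[i, j] is True when column i comes strictly before column j.
pos = torch.arange(n)
earlier = pos[:, None] < pos[None, :]
# Column j is a duplicate if some earlier column equals it.
is_dup = (same & earlier).any(dim=0)
first = torch.nonzero(~is_dup).squeeze(1)
print(first)  # tensor([0, 1, 2, 3, 4, 6])
```

The indices come out already in order of first appearance, because nonzero scans the mask from left to right.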