That is a good question I have stumbled over a couple of times myself. The short answer is that there is no guarantee whatsoever that torch.argmax (or torch.max(x, dim=k), which also returns indices when dim is specified) will return the same index consistently when the maximum value occurs more than once. Instead, it may return any valid index of the maximum, possibly a different one on each call. As this thread in the official forum discusses, this is considered desired behavior. (I know there is another thread I read a while ago that makes this more explicit, but I cannot find it again.)
Having said that, as this behavior was unacceptable for my use case, I wrote the following functions, which find the leftmost and rightmost index satisfying a predicate (be aware that condition is a function object you pass in; also note that both functions return index 0 when no element matches, so check for matches separately if that case matters to you):
import torch

def __consistent_args(input, condition, indices):
    assert len(input.shape) == 2, 'only works for batch x dim tensors along the dim axis'
    # positions satisfying the condition keep their index weight; all other
    # positions become 0, so argmax lands on the match with the largest weight
    mask = condition(input).float() * indices.unsqueeze(0).expand_as(input)
    return torch.argmax(mask, dim=1)
def consistent_find_leftmost(input, condition):
    # descending weights: the leftmost matching position gets the largest weight
    indices = torch.arange(input.size(1), 0, -1, dtype=torch.float, device=input.device)
    return __consistent_args(input, condition, indices)

def consistent_find_rightmost(input, condition):
    # ascending weights starting at 1 rather than 0, so a match at index 0 is
    # distinguishable from "no match"; the rightmost match gets the largest weight
    indices = torch.arange(1, input.size(1) + 1, dtype=torch.float, device=input.device)
    return __consistent_args(input, condition, indices)
# one example:
consistent_find_leftmost(torch.arange(10).unsqueeze(0), lambda x: x>5)
# will return:
# tensor([6])
Hope they help! (Oh, and please let me know if you have a better implementation that does the same.)
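One deterministic alternative I can think of (the function names below are my own, not from any library) avoids the weighting trick entirely: replace non-matching positions with a sentinel via torch.where, then take min or max. As a bonus, the sentinel value itself signals the no-match case.

```python
import torch

def find_leftmost(input, condition):
    # non-matching positions become n (the row length), so min() yields the
    # leftmost matching index, and a result of n means "no match"
    n = input.size(1)
    idx = torch.arange(n, device=input.device).expand_as(input)
    masked = torch.where(condition(input), idx, torch.full_like(idx, n))
    return masked.min(dim=1).values

def find_rightmost(input, condition):
    # non-matching positions become -1, so max() yields the rightmost
    # matching index, and a result of -1 means "no match"
    idx = torch.arange(input.size(1), device=input.device).expand_as(input)
    masked = torch.where(condition(input), idx, torch.full_like(idx, -1))
    return masked.max(dim=1).values
```

For example, find_leftmost(torch.arange(10).unsqueeze(0), lambda x: x > 5) gives tensor([6]), matching the weighted version above.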