Sorry for the short and probably over-simplified example; I fear a bigger one would be much harder to visualize, but I hope it suits your purpose.
My solution may seem a little complicated, but it's fully vectorized and uses no explicit Python loops.
Here's what I would do:
import torch


def vectorized_solution(composition_matrix, pred):
    """Return, for each row of ``pred``, the index of its first exact match in ``composition_matrix``.

    Rows of ``pred`` that do not occur in ``composition_matrix`` map to 0, the
    same convention as a plain per-row loop that leaves a zero-initialised
    output untouched. Note that this makes "no match" indistinguishable from
    "matched row 0".

    Args:
        composition_matrix: tensor of shape (n, k) holding the candidate rows.
        pred: tensor of shape (b, k) holding the rows to look up.

    Returns:
        int64 tensor of shape (b,) with the first matching row index (or 0).
    """
    n_rows = composition_matrix.shape[0]
    n_queries = pred.shape[0]
    if n_rows == 0:
        # Nothing can match; a min() over an empty dimension would raise.
        return torch.zeros(n_queries, dtype=torch.long, device=pred.device)

    # (b, 1, k) == (1, n, k) -> (b, n, k); a row matches only if every column agrees.
    matches = (pred.unsqueeze(1) == composition_matrix.unsqueeze(0)).all(dim=2)

    # Candidate index where a row matches, sentinel ``n_rows`` where it does not,
    # so the row-wise minimum is the first match.
    idxs = torch.arange(n_rows, device=pred.device).expand(n_queries, n_rows)
    first = idxs.masked_fill(~matches, n_rows).min(dim=1).values

    # Queries without any match still hold the sentinel; map them to 0.
    return first.masked_fill(first == n_rows, 0)


torch.manual_seed(0)
batchSize = 8
pred = torch.randint(0, 10, (batchSize, 2))
composition_matrix = torch.randint(0, 10, (14, 2))
vectorized_output = vectorized_solution(composition_matrix, pred)
def loop_solution(composition_matrix, pred):
    """Reference implementation: look up each row of ``pred`` with a Python loop.

    For every row of ``pred`` the result holds the index of the first row of
    ``composition_matrix`` that equals it, or 0 when there is none.

    Returns:
        float tensor of shape (pred.shape[0],), since it starts from ``torch.zeros``.
    """
    output = torch.zeros([pred.shape[0]])
    for q, row in enumerate(pred):
        # Indices of every composition row equal to this query row.
        hits = torch.where((composition_matrix == row).all(dim=1))[0]
        if hits.numel() > 0:
            output[q] = int(hits[0])
        # Otherwise keep the 0 that torch.zeros already put there.
    return output


output = loop_solution(composition_matrix, pred)
output:
composition_matrix =
tensor([[6, 8],
[4, 3],
[6, 9],
[1, 4],
[4, 1],
[9, 9],
[9, 0],
[1, 2],
[3, 0],
[5, 5],
[2, 9],
[1, 8],
[8, 3],
[6, 9]])
pred =
tensor([[4, 9],
[3, 0],
[3, 9],
[7, 3],
[7, 3],
[1, 6],
[6, 9],
[8, 6]])
output =
tensor([0., 8., 0., 0., 0., 0., 2., 0.])
vectorized_output =
tensor([0, 8, 0, 0, 0, 0, 2, 0])
Some timing results, using a much larger `composition_matrix` (14,000 rows):
import timeit

# Benchmark both lookups on a larger table. This uses the stdlib ``timeit``
# module instead of the IPython-only ``%timeit`` magic, so it runs as plain Python.
# NOTE: assumes vectorized_solution(composition_matrix, pred) and
# loop_solution(composition_matrix, pred) are defined before this point.
torch.manual_seed(0)
batchSize = 8
pred = torch.randint(0, 10, (batchSize, 2))
composition_matrix = torch.randint(0, 10, (14000, 2))

number, repeat = 1000, 5
for label, func in (
    ('timing the vectorized_solution:', vectorized_solution),
    ('timing the loop_solution:', loop_solution),
):
    print(label)
    runs = timeit.repeat(
        lambda: func(composition_matrix, pred), number=number, repeat=repeat
    )
    # Like %timeit, report the best run divided by the calls per run.
    per_call = min(runs) / number
    print(f'{number} loops, best of {repeat}: {per_call * 1e6:.1f} µs per loop')
output:
timing the vectorized_solution:
1000 loops, best of 5: 137 µs per loop
timing the loop_solution:
1000 loops, best of 5: 1.89 ms per loop