I have two 2D Torch Tensors of high dimensions, but of different size. One is a (23462505, 3) tensor and the other one is a (30856, 3) tensor. Let us call them A and B respectively.
My goal is to find the indices of rows appearing in A that also appear in B. With a toy example :
A = torch.tensor([[0,0,1], [0,0,2], [1,0,0], [1,0,1] , [1,0,2], [2,0,0], [2,0,1]])
B = torch.tensor([[0,0,2], [2,0,0], [1,0,2]])
Expected output : [1, 4, 5]
I have tried to work around this issue by converting my torch.tensors into np.array(). It does work on smaller datasets but for computational reasons, I would like to find a way to do it (in PyTorch or Numpy) in an efficient way.
My code:
# ent2id is a dict linking each Knowledge Graph entity (string) to a unique id
entity_ids = torch.arange(end=len(ent2id), device='cpu').unsqueeze(0)
# unpacking heads, relations and tails from the test dataset
heads, relations, tails = X_test[:,0], X_test[:,1], X_test[:,2]
all_entities = entity_ids.repeat(test.shape[0], 1)
heads = heads.reshape(-1, 1).repeat(1, all_entities.size()[1])
relations = relations.reshape(-1, 1).repeat(1, all_entities.size()[1])
tails = tails.reshape(-1, 1).repeat(1, all_entities.size()[1])
# concatenate train and valid triplets
triplets_trainvalid = torch.cat((X_train, X_val),0)
# generate all possible combinations of triplets
# by taking the original triplets in test set and replacing their heads
tmp_head_triplets = torch.stack((all_entities, relations, tails), dim=2).reshape(-1, 3)
# Same for tails
tmp_tail_triplets = torch.stack((heads, relations, all_entities), dim=2).reshape(-1, 3)
# For loop (very inefficient) to find the indices of triplets in 'tmp_head_triplets' that already occur in either train or validation set
# 1. For heads
idx = 0
true_idx_heads = []
tmp_head_triplets_lst = tmp_head_triplets.tolist()
triplets_trainvalid_lst = triplets_trainvalid.tolist()
for triple in tmp_head_triplets_lst :
if triple in triplets_trainvalid_lst :
true_idx_heads.append(idx)
idx += 1
# 2. For Tails
idx = 0
true_idx_tails = []
tmp_tail_triplets_lst = tmp_tail_triplets.tolist()
for triple in tmp_tail_triplets_lst :
if triple in triplets_trainvalid_lst :
true_idx_tails.append(idx)
idx += 1
Example (for heads):
triplets_trainvalid = torch.tensor([[1,2,4],
[2,2,4],
[1,3,4],
[2,1,4],
[3,2,1]])
tmp_head_triplets = torch.tensor([[1,1,1],
[2,1,1],
[3,1,1],
[4,1,1],
[1,1,2],
[2,1,2],
[3,1,2],
[4,1,2],
[1,2,4],
[2,2,4],
[3,2,4],
[4,2,4]])
idx = 0
true_idx_heads = []
tmp_head_triplets_lst = tmp_head_triplets.tolist()
triplets_trainvalid_lst = triplets_trainvalid.tolist()
for triple in tmp_head_triplets_lst :
if triple in triplets_trainvalid_lst :
true_idx_heads.append(idx)
idx += 1
print(true_idx_heads) # [8, 9]