
I have two large 2D Torch tensors of different sizes: one is a (23462505, 3) tensor and the other is a (30856, 3) tensor. Let us call them A and B respectively.

My goal is to find the indices of the rows of A that also appear in B. Here is a toy example:

```python
A = torch.tensor([[0,0,1], [0,0,2], [1,0,0], [1,0,1], [1,0,2], [2,0,0], [2,0,1]])
B = torch.tensor([[0,0,2], [2,0,0], [1,0,2]])
```

Expected output: [1, 4, 5]
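One common way to turn row matching into a 1D membership test is to pack each row into a single integer key. A minimal sketch using `torch.isin` (available in recent PyTorch versions, 1.10 and later), assuming both tensors hold non-negative integers and that M**3 fits in int64, where M is one more than the largest value (roughly M below 2 million); `rows_in` is a hypothetical helper name:

```python
import torch

def rows_in(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Return the indices of rows of A that also appear in B.

    Assumes non-negative integer tensors with 3 columns. Each row (x, y, z)
    is packed into the int64 key x*M*M + y*M + z, where M is one more than
    the largest value in either tensor, so distinct rows get distinct keys.
    """
    A, B = A.long(), B.long()
    M = int(max(A.max().item(), B.max().item())) + 1
    # Powers of M used to pack each row into one scalar key
    weights = torch.tensor([M * M, M, 1], dtype=torch.long, device=A.device)
    keys_A = (A * weights).sum(dim=1)
    keys_B = (B * weights).sum(dim=1)
    # Elementwise membership test on the 1D keys
    mask = torch.isin(keys_A, keys_B)
    return mask.nonzero(as_tuple=True)[0]

A = torch.tensor([[0,0,1], [0,0,2], [1,0,0], [1,0,1], [1,0,2], [2,0,0], [2,0,1]])
B = torch.tensor([[0,0,2], [2,0,0], [1,0,2]])
print(rows_in(A, B))  # tensor([1, 4, 5])
```

The packing step is what makes this fast: instead of comparing every row of A with every row of B, the problem becomes a single membership test between two 1D integer tensors.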

I have tried working around this by converting my tensors to NumPy arrays with np.array(). This works on smaller datasets, but it does not scale to tensors this large, so I am looking for an efficient way to do it in PyTorch or NumPy.
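For the NumPy route, the same packing idea can be expressed with `np.ravel_multi_index`, which maps each row to its flat index in an array whose shape is given by per-column bounds, followed by `np.isin`. A minimal sketch, assuming non-negative integer rows and that the product of the per-column bounds fits in a 64-bit index; `rows_in_np` is a hypothetical name:

```python
import numpy as np

def rows_in_np(A: np.ndarray, B: np.ndarray) -> np.ndarray:
    """Indices of rows of A that also appear in B (non-negative integer rows)."""
    # Per-column upper bounds, so every possible row gets a distinct flat index
    dims = tuple(int(d) for d in np.maximum(A.max(axis=0), B.max(axis=0)) + 1)
    # Convert each row (x, y, z) to its flat index in an array of shape `dims`
    keys_A = np.ravel_multi_index(tuple(A.T), dims)
    keys_B = np.ravel_multi_index(tuple(B.T), dims)
    # 1D membership test, then the positions of the True entries
    return np.flatnonzero(np.isin(keys_A, keys_B))

A_np = np.array([[0,0,1], [0,0,2], [1,0,0], [1,0,1], [1,0,2], [2,0,0], [2,0,1]])
B_np = np.array([[0,0,2], [2,0,0], [1,0,2]])
print(rows_in_np(A_np, B_np))  # [1 4 5]
```

Starting from torch tensors, `A.cpu().numpy()` gives the arrays to pass in.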

My code:

```python
import torch

# ent2id is a dict linking each Knowledge Graph entity (string) to a unique id
# X_train, X_val and X_test are (N, 3) integer tensors of (head, relation, tail) ids
entity_ids = torch.arange(end=len(ent2id), device='cpu').unsqueeze(0)
# unpacking heads, relations and tails from the test dataset
heads, relations, tails = X_test[:, 0], X_test[:, 1], X_test[:, 2]
all_entities = entity_ids.repeat(X_test.shape[0], 1)
heads = heads.reshape(-1, 1).repeat(1, all_entities.size(1))
relations = relations.reshape(-1, 1).repeat(1, all_entities.size(1))
tails = tails.reshape(-1, 1).repeat(1, all_entities.size(1))
# concatenate train and valid triplets
triplets_trainvalid = torch.cat((X_train, X_val), 0)
# generate all possible combinations of triplets
# by taking the original triplets in the test set and replacing their heads
tmp_head_triplets = torch.stack((all_entities, relations, tails), dim=2).reshape(-1, 3)
# Same for tails
tmp_tail_triplets = torch.stack((heads, relations, all_entities), dim=2).reshape(-1, 3)

# For loops (very inefficient) to find the indices of triplets in 'tmp_head_triplets'
# (resp. 'tmp_tail_triplets') that already occur in either the train or validation set.
# `triple in some_list` scans the whole list, so each loop costs
# O(len(candidates) * len(triplets_trainvalid)).
triplets_trainvalid_lst = triplets_trainvalid.tolist()

# 1. For heads
true_idx_heads = []
for idx, triple in enumerate(tmp_head_triplets.tolist()):
    if triple in triplets_trainvalid_lst:
        true_idx_heads.append(idx)

# 2. For tails
true_idx_tails = []
for idx, triple in enumerate(tmp_tail_triplets.tolist()):
    if triple in triplets_trainvalid_lst:
        true_idx_tails.append(idx)
```
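If the loop structure is kept, most of its cost comes from `triple in triplets_trainvalid_lst`, which scans the whole list for every candidate. A minimal sketch that swaps the list for a set of tuples, so each check becomes an average constant-time hash lookup; `matching_indices` is a hypothetical name:

```python
import torch

def matching_indices(candidates: torch.Tensor, known: torch.Tensor) -> list:
    # Build the lookup structure once: tuples are hashable, lists are not
    known_set = set(map(tuple, known.tolist()))
    # One pass over the candidates; each `in` check is a hash lookup
    return [idx for idx, triple in enumerate(map(tuple, candidates.tolist()))
            if triple in known_set]

A = torch.tensor([[0,0,1], [0,0,2], [1,0,0], [1,0,1], [1,0,2], [2,0,0], [2,0,1]])
B = torch.tensor([[0,0,2], [2,0,0], [1,0,2]])
print(matching_indices(A, B))  # [1, 4, 5]
```

In the code above this would read `true_idx_heads = matching_indices(tmp_head_triplets, triplets_trainvalid)`, and likewise for tails. It still converts every candidate into Python objects, so a fully vectorized approach is likely to be faster at the question's scale.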

Example (for heads):

```python
import torch

triplets_trainvalid = torch.tensor([[1,2,4],
                                    [2,2,4],
                                    [1,3,4],
                                    [2,1,4],
                                    [3,2,1]])

tmp_head_triplets = torch.tensor([[1,1,1],
                                  [2,1,1],
                                  [3,1,1],
                                  [4,1,1],
                                  [1,1,2],
                                  [2,1,2],
                                  [3,1,2],
                                  [4,1,2],
                                  [1,2,4],
                                  [2,2,4],
                                  [3,2,4],
                                  [4,2,4]])

true_idx_heads = []
tmp_head_triplets_lst = tmp_head_triplets.tolist()
triplets_trainvalid_lst = triplets_trainvalid.tolist()
for idx, triple in enumerate(tmp_head_triplets_lst):
    if triple in triplets_trainvalid_lst:
        true_idx_heads.append(idx)

print(true_idx_heads)  # [8, 9]
```
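Because only the indices matter, another option that does not rely on `torch.isin` is to sort the packed keys of the known triplets once and binary-search every candidate key with `torch.searchsorted`. A minimal sketch on the example above, assuming non-negative integer ids, a non-empty set of known triplets, and M**3 fitting in int64 (M = largest id + 1); `rows_in_sorted` is a hypothetical name:

```python
import torch

def rows_in_sorted(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Indices of rows of A that also appear in B (B assumed non-empty)."""
    A, B = A.long(), B.long()
    # Pack each row (x, y, z) into the int64 key x*M*M + y*M + z
    M = int(torch.cat([A, B]).max()) + 1
    weights = torch.tensor([M * M, M, 1], dtype=torch.long, device=A.device)
    keys_A = (A * weights).sum(dim=1)
    keys_B, _ = torch.sort((B * weights).sum(dim=1))
    # Leftmost position where each key of A would be inserted into keys_B
    pos = torch.searchsorted(keys_B, keys_A)
    # Keys larger than every key of B get len(keys_B); clamp to stay in bounds
    pos = pos.clamp(max=len(keys_B) - 1)
    # A key is present exactly when the element at its insertion point equals it
    return (keys_B[pos] == keys_A).nonzero(as_tuple=True)[0]

triplets_trainvalid = torch.tensor([[1,2,4], [2,2,4], [1,3,4], [2,1,4], [3,2,1]])
tmp_head_triplets = torch.tensor([[1,1,1], [2,1,1], [3,1,1], [4,1,1],
                                  [1,1,2], [2,1,2], [3,1,2], [4,1,2],
                                  [1,2,4], [2,2,4], [3,2,4], [4,2,4]])
print(rows_in_sorted(tmp_head_triplets, triplets_trainvalid))  # tensor([8, 9])
```

Sorting costs O(m log m) for the m known triplets, and each of the n candidates then needs one O(log m) binary search.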
elkmyr
  • I did not understand the backup explanation on which part gets assigned `+INF`. Is it the value in `A`, or `B`? From my understanding, it seems unnecessary to actually maintain an index, and rather simplify it to a boolean "is it contained or not" problem. But maybe you need to maintain the full index? Also, is the order important for this problem, or could one "sort" the vectors before processing? – dennlinger May 09 '22 at 09:08
  • Hi @dennlinger, thank you for your reply. Indeed, it is better to maintain the full index, more precisely the indices of triplets in 'tmp_head_triplets_lst' (resp. 'tmp_tail_triplets_lst') that also occur in 'triplets_trainvalid'. No, I do not think the order is important here, as long as I am able to recover the generated 'fake' triplets that are actually true (i.e. already occurring in either the train or validation set) – elkmyr May 09 '22 at 10:59
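Since the goal is the positions in `tmp_head_triplets`, one option is to compute them without building the candidate tensors at all. Given how the question builds them (entity ids from `torch.arange`, so `0` to `len(ent2id) - 1`), position `i * n_ent + e` holds `(e, r_i, t_i)` for test row `i`, so the matching positions come from the train/valid triplets that share `(r_i, t_i)`. A minimal sketch of that idea as a sort-based join, assuming non-negative integer ids below `n_ent`; `true_head_indices` is a hypothetical name and `n_ent` is meant to be `len(ent2id)`:

```python
import torch

def true_head_indices(X_test: torch.Tensor, known: torch.Tensor, n_ent: int) -> torch.Tensor:
    """Flat indices into tmp_head_triplets whose triplet occurs in `known`.

    Assumes tmp_head_triplets is built as in the question, so index
    i * n_ent + e stands for (e, r_i, t_i), with (h_i, r_i, t_i) = X_test[i].
    """
    X_test, known = X_test.long(), known.long()
    dev = X_test.device
    # Pack (relation, tail) into one key; tails are entity ids < n_ent
    key_known = known[:, 1] * n_ent + known[:, 2]
    key_test = X_test[:, 1] * n_ent + X_test[:, 2]
    # Sort the known triplets by key so equal keys are contiguous
    key_sorted, order = torch.sort(key_known)
    heads_sorted = known[order, 0]
    # For each test row, the slice [lo, hi) of known triplets sharing its (r, t)
    lo = torch.searchsorted(key_sorted, key_test)
    hi = torch.searchsorted(key_sorted, key_test, right=True)
    counts = hi - lo
    # Expand each slice into one entry per matching known triplet
    test_rows = torch.repeat_interleave(torch.arange(len(X_test), device=dev), counts)
    group_start = torch.repeat_interleave(torch.cumsum(counts, 0) - counts, counts)
    offsets = torch.arange(int(counts.sum()), device=dev) - group_start
    heads = heads_sorted[torch.repeat_interleave(lo, counts) + offsets]
    # Sorted, de-duplicated flat indices (handles repeated known triplets)
    return torch.unique(test_rows * n_ent + heads)

X_test_small = torch.tensor([[0, 1, 2], [3, 0, 1]])
known_small = torch.tensor([[1, 1, 2], [2, 1, 2], [0, 0, 1], [3, 1, 3]])
print(true_head_indices(X_test_small, known_small, n_ent=4))  # tensor([1, 2, 4])
```

For tails, reversing the column order of both inputs turns each `(h, r, t)` into `(t, r, h)`, so `true_head_indices(torch.flip(X_test, dims=[1]), torch.flip(triplets_trainvalid, dims=[1]), n_ent)` should give the positions in `tmp_tail_triplets`.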
