
I'm working on a big matrix (10-80k rows; 3k columns) and I want to compute a custom pairwise distance on it, and do it fast. I tried Armadillo, but with huge data it is still slow. I then tried torch with CUDA acceleration and its built-in Euclidean distance, which is really fast (about 100 times faster). Now I want to compute a custom pairwise distance: for each pair of rows (a and b), take the standard deviation of a_i * b_i (where i indexes the columns). For example:

my_mat:
   |1    |2    |3    |4
a  |5    |3    |0    |4
b  |1    |6    |2    |3

a//b dist = std(5*1,3*6,0*2,4*3)
          = std(5,18,0,12)
          = 7.889867

My idea: start from my two-dimensional (N, M) tensor (my_mat) and create a new three-dimensional (N, N, P) tensor, where the P dimension (of length M) stores the list of column-wise products for each pair of rows:

3_dim_tens :

   |a                      |b
a  |Pdim(5*5,3*3,0*0,4*4)  |Pdim(5*1,3*6,0*2,4*3)
b  |Pdim(5*1,3*6,0*2,4*3)  |Pdim(5*5,3*3,0*0,4*4)

Then, if I reduce the P dimension with std(), I get a two-dimensional (N, N) pairwise matrix containing my custom distance. (This is essentially the matrix product my_mat * t(my_mat), but with std in place of the summation.)

Is it possible to do this with torch, or is there another way to compute a custom pairwise distance?

Ludo Vic

1 Answer


I think the most intuitive way is using einsum for this:

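import torch

# Each row is one vector to compare; columns are the features.
a = torch.tensor([[5.0, 3, 0, 4], [1, 6, 2, 3]])
# The einsum builds the (N, N, M) tensor of products a[i, j] * a[k, j];
# std(dim=2) then reduces the last (third) dimension to give an (N, N) matrix.
b = torch.einsum('ij,kj->ikj', a, a).std(dim=2)
print(b)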
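One caveat for the sizes mentioned in the question: the intermediate (N, N, M) tensor grows quickly (10k rows by 3k columns in float32 is roughly 1.2 TB), so the einsum version may not fit in memory. A possible workaround, sketched below and not part of the original answer, uses the identity var = (sum(p^2) - sum(p)^2 / M) / (M - 1), where both sums over the products p can be obtained from two matrix multiplications. The function name `pairwise_product_std` is made up for illustration, and float64 is used here on the assumption that cancellation in the subtraction could otherwise hurt accuracy.

```python
import torch

def pairwise_product_std(x: torch.Tensor) -> torch.Tensor:
    """Sample std (n - 1) of the element-wise products of every row pair.

    Hypothetical helper: avoids materialising the (N, N, M) tensor.
    """
    m = x.shape[1]
    s1 = x @ x.T                # s1[i, k] = sum_j x[i, j] * x[k, j]
    s2 = (x * x) @ (x * x).T    # s2[i, k] = sum_j (x[i, j] * x[k, j]) ** 2
    var = (s2 - s1 * s1 / m) / (m - 1)
    # Rounding can push tiny variances slightly below zero, so clamp first.
    return var.clamp_min(0).sqrt()

a = torch.tensor([[5.0, 3, 0, 4], [1, 6, 2, 3]], dtype=torch.float64)
fast = pairwise_product_std(a)
reference = torch.einsum('ij,kj->ikj', a, a).std(dim=2)
print(fast)
print(torch.allclose(fast, reference))
```

The off-diagonal entry should match the 7.889867 from the question. Note that the (N, N) result is itself large at 80k rows (80,000² float32 values is about 25.6 GB), so computing it in blocks of rows may still be necessary.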
flawr
  • Thanks, that works well! Except that the dim to reduce for std is 3 in my case, but that's not important. Just one more question: how can I remove the zeros before applying std? – Ludo Vic Nov 14 '22 at 15:33
  • Not sure if you're aware but the dimensions use 0-based indexing, so `dim=2` reduces the third dimension. Please consider asking that as a separate question, but in short: You cannot just "remove" zeros from a numpy array, as you'd get a jagged array, which is not supported in numpy. But for `std` you could write your own function, but that is beyond the scope of this question. – flawr Nov 14 '22 at 15:38
  • I'm using R (sorry sorry sorry...) and dimension indexing seems to start at 1 there. Stay tuned for the new question! Thanks – Ludo Vic Nov 15 '22 at 07:56
  • No worries, MATLAB and Octave also use 1-based indexing, so I completely understand the problem:) – flawr Nov 15 '22 at 10:19
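
On the follow-up about ignoring zeros: one way to "write your own function", as suggested in the comments, is to keep the tensor rectangular and count the non-zero products per pair instead of removing them. The sketch below is only one interpretation of that request. It assumes that "removing zeros" means dropping the columns where the product is zero (that is, where either row has a zero), and the name `pairwise_product_std_nonzero` is hypothetical. Pairs with fewer than two non-zero products are returned as NaN, which is a choice made here, not something specified in the thread.

```python
import torch

def pairwise_product_std_nonzero(x: torch.Tensor) -> torch.Tensor:
    """Sample std of the non-zero products for every row pair (hypothetical helper)."""
    nz = (x != 0).to(x.dtype)
    n = nz @ nz.T               # n[i, k] = number of columns where both rows are non-zero
    s1 = x @ x.T                # zero products add nothing to the sums
    s2 = (x * x) @ (x * x).T
    n_safe = n.clamp_min(2)     # avoid dividing by zero; such pairs are masked below
    var = (s2 - s1 * s1 / n_safe) / (n_safe - 1)
    var = torch.where(n > 1, var.clamp_min(0), torch.full_like(var, float('nan')))
    return var.sqrt()

a = torch.tensor([[5.0, 3, 0, 4], [1, 6, 2, 3]], dtype=torch.float64)
masked = pairwise_product_std_nonzero(a)
# For rows a and b the non-zero products are 5, 18 and 12.
expected = torch.tensor([5.0, 18.0, 12.0], dtype=torch.float64).std()
print(masked[0, 1], expected)
```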