
After profiling my data-loading function, I found that the following lines are a major bottleneck:

 dist_1 = dist[random_labels, :][:, random_labels]  
 dist_2 = dist[other_random_labels, :][:, other_random_labels]

where dist has shape (6000, 6000) and random_labels has length 5000.
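Chained indexing copies the data twice. dist[random_labels, :] first materializes a 5000 × 6000 intermediate (5000 · 6000 · 8 bytes = 240 MB, assuming dist is float64), and [:, random_labels] then copies the 5000 × 5000 result (200 MB) out of it. One alternative worth trying, which is not from the question, is np.ix_: it turns the label array into index arrays that broadcast, so a single advanced-indexing operation gathers the submatrix without the row intermediate. A minimal sketch, using random stand-in data at reduced size (the real dist and random_labels come from the load function):

```python
import numpy as np

# Hypothetical stand-in data; the question uses n=6000, k=5000.
rng = np.random.default_rng(0)
n, k = 600, 500
dist = rng.random((n, n))                      # float64 assumed
random_labels = rng.choice(n, size=k, replace=False)

# Two-step version from the question: first builds a (k, n) copy of the rows.
dist_1_two_step = dist[random_labels, :][:, random_labels]

# One-step version: np.ix_ returns index arrays of shape (k, 1) and (1, k),
# which broadcast so one advanced index selects the (k, k) submatrix.
dist_1 = dist[np.ix_(random_labels, random_labels)]

print(dist_1.shape)                             # (500, 500)
print(np.array_equal(dist_1, dist_1_two_step))  # True
```

Both versions return a new array (a copy), so the gain, if any, comes from skipping the intermediate rather than from avoiding the copy altogether.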
I tried using np.take instead, but

 np.take(dist, [random_labels, random_labels]) == dist[random_labels, :][:, random_labels]

is False, and np.take(dist, [random_labels, random_labels]) has shape (2, 5000) instead of (5000, 5000).
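The (2, 5000) shape comes from how np.take behaves when no axis is given: it indexes the flattened array, so each label is treated as a position in dist.ravel() rather than as a row or column index, and the output takes the shape of the index array. A toy 4 × 4 example (values chosen only for illustration):

```python
import numpy as np

# Toy stand-in for dist: entry (i, j) holds 4 * i + j.
dist = np.arange(16).reshape(4, 4)
labels = np.array([2, 0, 3])

# No axis: indices refer to dist.ravel(), and the result has the index shape.
idx = np.array([labels, labels])               # shape (2, 3)
flat = np.take(dist, idx)
print(flat.shape)  # (2, 3)
print(flat)
# [[2 0 3]
#  [2 0 3]]

# Same thing written as plain indexing on the flattened array.
print(np.array_equal(flat, dist.ravel()[idx]))  # True
```

Passing axis=0 and then axis=1 to two separate np.take calls selects rows and then columns, which is what the submatrix needs.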
Is there an efficient way to extract this submatrix in NumPy?
Edit: this is the closest I've got so far:

 dist_1 = np.take(np.take(dist, random_labels, axis=0), random_labels, axis=1)
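To choose between the candidates, it helps to first confirm that they agree and then time them on the real data. A minimal benchmarking sketch, assuming random float64 stand-in data at reduced size (set n, k = 6000, 5000 to match the question). The np.ix_ variant is an alternative that does not appear in the question, and relative timings depend on NumPy version, dtype, and hardware, so no numbers are claimed here:

```python
import timeit

import numpy as np

# Hypothetical stand-in data; use n, k = 6000, 5000 to match the question.
rng = np.random.default_rng(0)
n, k = 600, 500
dist = rng.random((n, n))
random_labels = rng.choice(n, size=k, replace=False)

candidates = {
    "chained fancy indexing": lambda: dist[random_labels, :][:, random_labels],
    "nested np.take": lambda: np.take(np.take(dist, random_labels, axis=0),
                                      random_labels, axis=1),
    "np.ix_": lambda: dist[np.ix_(random_labels, random_labels)],
}

# All candidates must return the same (k, k) submatrix before timing matters.
reference = candidates["chained fancy indexing"]()
results = {name: fn() for name, fn in candidates.items()}
for name, out in results.items():
    assert np.array_equal(out, reference), name

# Best of 3 repeats, each running the call 5 times; report seconds per call.
timings = {name: min(timeit.repeat(fn, number=5, repeat=3)) / 5
           for name, fn in candidates.items()}
for name, seconds in sorted(timings.items(), key=lambda item: item[1]):
    print(f"{name:>24}: {seconds * 1e3:.2f} ms per call")
```

Taking the minimum over repeats reduces noise from other processes; running it at the full 6000 × 6000 size is what actually answers the question for this workload.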
DsCpp
