After profiling my data-loading function, I've found that the following lines are a major bottleneck:
dist_1 = dist[random_labels, :][:, random_labels]
dist_2 = dist[other_random_labels, :][:, other_random_labels]
where dist has shape (6000, 6000) and each array of random labels has length 5000.
I'm trying to use np.take, but np.take(dist_1, [random_labels, random_labels]) == dist_1[random_labels, :][:, random_labels] is False. The shape of np.take(dist_1, [random_labels, random_labels]) is (2, 5000), not (5000, 5000).
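As far as I can tell, the (2, 5000) shape comes from calling np.take without an axis argument: the indices are then applied to the flattened array, and the result takes the shape of the index array. Here is a minimal sketch of that behaviour on a small stand-in matrix (the 6x6 size and the name labels are assumptions for illustration, not my real data):

```python
import numpy as np

# Small stand-in for the 6000x6000 matrix (assumed size, for illustration only).
dist = np.arange(36).reshape(6, 6)
# Hypothetical stand-in for random_labels.
labels = np.array([4, 0, 2])

# Without axis, np.take indexes into dist.ravel(); the output has the
# shape of the index array, here (2, 3).
flat_take = np.take(dist, [labels, labels])

# The chained indexing from the question selects rows, then columns.
submatrix = dist[labels, :][:, labels]

print(flat_take.shape)
print(submatrix.shape)
```

So the two expressions do not even produce arrays of the same shape, which would explain why the comparison fails.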
Is there an efficient way of doing this in numpy?
Edit: this is the closest I've got:
dist_1 = np.take(np.take(dist, random_labels, axis=0), random_labels, axis=1)
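To check that the nested np.take really matches the original chained indexing, here is a small sketch on a reduced problem. The 60x60 size, the seed, and sampling the labels without replacement are assumptions made for the example. It also compares against np.ix_, which builds an open mesh from the two index arrays and might be another candidate; I have not timed it, so I make no claim about which is fastest.

```python
import numpy as np

rng = np.random.default_rng(0)

# Reduced stand-ins for the real data (assumed sizes, for illustration only).
dist = rng.random((60, 60))
random_labels = rng.choice(60, size=50, replace=False)

# Original chained fancy indexing: rows first, then columns.
chained = dist[random_labels, :][:, random_labels]

# Nested np.take along each axis, as in the edit above.
nested_take = np.take(np.take(dist, random_labels, axis=0), random_labels, axis=1)

# Single indexing step using the open mesh built by np.ix_.
via_ix = dist[np.ix_(random_labels, random_labels)]

print(np.array_equal(chained, nested_take))
print(np.array_equal(chained, via_ix))
```

All three give the same (50, 50) submatrix here, so any speed difference would need to be measured on the real 6000x6000 array.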