I would like to compute the permutation parity (sign) of each vector in a batch, using Python/JAX.
# Batch of 3 matrices; every row is a length-4 vector whose sorting
# permutation we want the parity of.
n = jnp.array(
    [
        [
            [0.0, 0.0, 1.0, 1.0],
            [0.0, 0.0, 1.0, 1.0],
            [1.0, 1.0, 0.0, 0.0],
            [1.0, 1.0, 0.0, 0.0],
        ],
        [
            [1.0, 0.0, 1.0, 0.0],
            [0.0, 1.0, 0.0, 1.0],
            [1.0, 0.0, 1.0, 0.0],
            [0.0, 1.0, 0.0, 1.0],
        ],
        [
            [0.0, 1.0, 1.0, 0.0],
            [1.0, 0.0, 0.0, 1.0],
            [1.0, 0.0, 0.0, 1.0],
            [0.0, 1.0, 1.0, 0.0],
        ],
    ]
)
# Argsort indices for every row; vmap strips the batch axis.
sorted_index = jax.vmap(sorted_idx)(n)
# Gather each row into sorted order. The two nested vmaps (default in_axes=0
# for both arguments) strip the batch and row axes, so ``sorted_perm`` sees one
# 1-D vector and its matching 1-D index array at a time.
sorted_perms = jax.vmap(jax.vmap(sorted_perm))(n, sorted_index)
# The outer vmap maps over the batch axis; the ``@jax.vmap`` on ``parities``
# itself maps over rows.
# NOTE(review): this rebinds ``parities`` from the function to its result, so
# re-running this cell fails; consider storing the result under another name.
parities = jax.vmap(parities)(sorted_index)
I expect the following result:
# Expected output: every row of every batch sorts to [0, 0, 1, 1].
sorted_elements = [
    [[0.0, 0.0, 1.0, 1.0] for _row in range(4)] for _batch in range(3)
]
# Expected parity of each row's sorting permutation: batches 0 and 2 are
# even (+1), batch 1 is odd (-1).
parities = [[sign] * 4 for sign in (1, -1, 1)]
I tried the following:
# sort the array and return the arg_sort indices
def sorted_idx(permutations):
sort_idx = jnp.argsort(permutations)
return sort_idx
# sort the permutations (vectors) given the sorted_indices
def sorted_perm(permutations, sort_idx):
perm = permutations[sort_idx]
return perm
# Calculate the permutation cycle, from which we compute the permutation parity
@jax.vmap
def parities(sort_idx):
length = len(sort_idx)
elements_seen = jnp.zeros(length)
cycles = 0
for index in range(length):
if elements_seen[index] == True:
continue
cycles += 1
current = index
if elements_seen[current] == False:
elements_seen.at[current].set(True)
current = sort_idx[current]
is_even = (length - cycles) % 2 == 0
return +1 if is_even else -1
But I get the following: parities= [[1 1 1 1], [1 1 1 1], [1 1 1 1]]
I get for each permutation vector a parity factor of 1, which is wrong....