
I would like to find the permutation parity (sign) for each vector in a given batch, in Python/JAX.

n = jnp.array([[[0., 0., 1., 1.],
                [0., 0., 1., 1.],
                [1., 1., 0., 0.],
                [1., 1., 0., 0.]],

               [[1., 0., 1., 0.],
                [0., 1., 0., 1.],
                [1., 0., 1., 0.],
                [0., 1., 0., 1.]],

               [[0., 1., 1., 0.],
                [1., 0., 0., 1.],
                [1., 0., 0., 1.],
                [0., 1., 1., 0.]]])

sorted_index = jax.vmap(sorted_idx)(n)
sorted_perms = jax.vmap(jax.vmap(sorted_perm, in_axes=(0, 0)), in_axes=(0,0))(n, sorted_index)
parities = jax.vmap(parities)(sorted_index)
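A side note on the first two lines: `jnp.argsort` sorts along the last axis by default, so the batch can probably be handled without any explicit `vmap`, and the nested `vmap` over `sorted_perm` can be replaced by `jnp.take_along_axis`. This is a minimal sketch that repeats the array `n` so it runs on its own:

```python
import jax.numpy as jnp

n = jnp.array([[[0., 0., 1., 1.],
                [0., 0., 1., 1.],
                [1., 1., 0., 0.],
                [1., 1., 0., 0.]],
               [[1., 0., 1., 0.],
                [0., 1., 0., 1.],
                [1., 0., 1., 0.],
                [0., 1., 0., 1.]],
               [[0., 1., 1., 0.],
                [1., 0., 0., 1.],
                [1., 0., 0., 1.],
                [0., 1., 1., 0.]]])

# argsort along the last axis handles every row of every matrix in the batch.
sorted_index = jnp.argsort(n, axis=-1)

# Gather each row's elements in sorted order. This gives the same result as
# the nested vmap over sorted_perm.
sorted_elements = jnp.take_along_axis(n, sorted_index, axis=-1)
```

Here `sorted_elements` should equal `jnp.sort(n, axis=-1)`, and every row should be `[0., 0., 1., 1.]`.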

I expect the following solution:

sorted_elements = [[[0., 0., 1., 1.],
                    [0., 0., 1., 1.],
                    [0., 0., 1., 1.],
                    [0., 0., 1., 1.]],

                   [[0., 0., 1., 1.],
                    [0., 0., 1., 1.],
                    [0., 0., 1., 1.],
                    [0., 0., 1., 1.]],

                   [[0., 0., 1., 1.],
                    [0., 0., 1., 1.],
                    [0., 0., 1., 1.],
                    [0., 0., 1., 1.]]]

parities = [[1, 1, 1, 1],
            [-1, -1, -1, -1],
            [1, 1, 1, 1]]

I tried the following:

# sort the array and return the arg_sort indices
def sorted_idx(permutations):
    sort_idx = jnp.argsort(permutations)
    return sort_idx

# sort the permutations (vectors) given the sorted_indices
def sorted_perm(permutations, sort_idx):
    perm = permutations[sort_idx]
    return perm


# Calculate the permutation cycle, from which we compute the permutation parity

@jax.vmap
def parities(sort_idx):
    length = len(sort_idx)
    elements_seen = jnp.zeros(length)
    cycles = 0

    for index in range(length):
        if elements_seen[index] == True:
            continue
        cycles += 1
        current = index
        if elements_seen[current] == False:
            elements_seen.at[current].set(True)
            current = sort_idx[current]

    is_even = (length - cycles) % 2 == 0
    return +1 if is_even else -1

But I get the following: parities = [[1 1 1 1], [1 1 1 1], [1 1 1 1]]

That is, I get a parity of +1 for every permutation vector, which is wrong.
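Two details in the loop explain why every result comes out as exactly +1. These are separate from the vmap issue. First, JAX arrays are immutable, so `elements_seen.at[current].set(True)` returns a new array that is then discarded, and `elements_seen` stays all zeros. Second, the inner `if` should be a `while` so that the whole cycle is followed. Because no element is ever marked as seen, every index starts a new cycle, `cycles` ends up equal to `length`, and the parity is always even.

A plain NumPy reference version of the same cycle-counting idea, with both points fixed, may be useful for checking results. It is a sketch that runs eagerly and is not meant to be vmapped. `parity_reference` is a hypothetical name.

```python
import numpy as np

def parity_reference(sort_idx):
    """Sign (+1 or -1) of a permutation given as a sequence of indices.

    Plain Python/NumPy: runs eagerly, not under jax.jit or jax.vmap.
    """
    perm = np.asarray(sort_idx)
    length = len(perm)
    elements_seen = np.zeros(length, dtype=bool)
    cycles = 0
    for index in range(length):
        if elements_seen[index]:
            continue
        cycles += 1
        current = index
        # Follow the cycle until we come back to an element already visited.
        while not elements_seen[current]:
            elements_seen[current] = True  # in-place update works for NumPy arrays
            current = perm[current]
    # A permutation with `cycles` cycles is a product of (length - cycles)
    # transpositions, so its parity follows from that count.
    return 1 if (length - cycles) % 2 == 0 else -1
```

Applying it row by row to `np.argsort(np.asarray(n), axis=-1, kind="stable")` should reproduce the expected parities listed above.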

relaxon
  • I don't think permutation parity is well-defined when two (or more) elements of the array are equal. – Joffan May 16 '22 at 21:24
  • That is why I pass the sorted indices (something like [[0,1,2,3],[0,2,3,1],...]) as the input to my 'parities' function. – relaxon May 16 '22 at 21:49

1 Answer


Your routine doesn't work because you're attempting to vmap over Python control flow, and that must be done very carefully (see JAX Sharp Bits: Control Flow). I suspect it would be somewhat complicated to rewrite your iterative parity approach in terms of jax.lax control flow operators, but there might be another way.
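For comparison, this is a sketch of how the cycle-counting loop might be translated into jax.lax control flow. It uses `lax.fori_loop` for the outer loop and `lax.while_loop` to follow each cycle, so no Python `if` depends on traced values and the function can be vmapped. `parity_lax` is a hypothetical name, and the input is assumed to be a 1-D integer array holding a permutation of 0..n-1.

```python
import jax
import jax.numpy as jnp
from jax import lax

def parity_lax(perm):
    """Permutation sign via cycle counting, written with lax control flow."""
    n = perm.shape[0]

    def visit_cycle(i, state):
        seen, cycles = state
        # Index i starts a new cycle only if no earlier cycle visited it.
        starts_cycle = jnp.logical_not(seen[i])

        def not_seen(carry):
            seen, current = carry
            return jnp.logical_not(seen[current])

        def mark_and_step(carry):
            seen, current = carry
            # .at[].set returns a new array, so it must be carried forward.
            return seen.at[current].set(True), perm[current]

        start = jnp.asarray(i, dtype=perm.dtype)
        seen, _ = lax.while_loop(not_seen, mark_and_step, (seen, start))
        return seen, cycles + starts_cycle.astype(jnp.int32)

    seen0 = jnp.zeros(n, dtype=bool)
    _, cycles = lax.fori_loop(0, n, visit_cycle, (seen0, jnp.int32(0)))
    # (n - cycles) transpositions: even -> +1, odd -> -1.
    return jnp.where((n - cycles) % 2 == 0, 1, -1)
```

Under vmap, `lax.while_loop` keeps iterating until every batch element has finished its cycle, so it is usually less efficient than a loop-free formulation. The intended usage would be `jax.vmap(jax.vmap(parity_lax))(jnp.argsort(n, axis=-1))`.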

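The parity of a permutation equals the determinant of its permutation matrix (+1 for even, -1 for odd). The Jacobian of a sort evaluated at a permutation p is exactly the permutation matrix that sorts p, and that sorting permutation is the inverse of p, so it has the same parity. You could therefore (ab)use JAX's automatic differentiation of the sort operator to compute the parities very concisely: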

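import jax
import jax.numpy as jnp

def compute_parity(p):
  # The Jacobian of jnp.sort at p is the permutation matrix that sorts p;
  # its determinant is +1 (even) or -1 (odd). Rounding guards against
  # floating-point error before casting to int.
  jac = jax.jacobian(jnp.sort)(p.astype(float))
  return jnp.round(jnp.linalg.det(jac)).astype(int)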

sorted_index = jnp.argsort(n, axis=-1)
parities = jax.vmap(jax.vmap(compute_parity))(sorted_index)

print(parities)
# [[ 1  1  1  1]
#  [-1 -1 -1 -1]
#  [ 1  1  1  1]]
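To check the claim about the Jacobian directly, this small sketch compares `jax.jacobian(jnp.sort)` with a permutation matrix built explicitly from `jnp.argsort`. The vector `p` is an arbitrary example with distinct entries, chosen so that the sort order is unambiguous. It does not come from the question.

```python
import jax
import jax.numpy as jnp

p = jnp.array([1., 3., 0., 2.])

# d sort(p)[i] / d p[j] is 1 exactly when p[j] is the element that ends up
# in position i, i.e. when j == argsort(p)[i].
jac = jax.jacobian(jnp.sort)(p)

# The same matrix built explicitly: identity rows reordered by argsort(p).
perm_matrix = jnp.eye(4)[jnp.argsort(p)]

print(bool(jnp.array_equal(jac, perm_matrix)))  # True
print(jnp.round(jnp.linalg.det(jac)))           # -1.0 (argsort(p) is a 4-cycle, i.e. odd)
```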

This does end up being O(N^3), where N is the length of the permutations, because of the determinant. However, given how XLA executes computations, particularly on accelerators such as GPUs, the vectorized approach will likely be more efficient than an iterative approach for reasonably sized N.
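If the O(N^3) cost of the determinant becomes a concern, another loop-free option is to count inversions. This is a different technique from the Jacobian trick in this answer. An inversion is a pair i < j with perm[i] > perm[j], and a permutation is even exactly when its number of inversions is even. This takes O(N^2) work and vectorizes directly. The following is a minimal sketch in which `parity_inversions` is a hypothetical name.

```python
import jax
import jax.numpy as jnp

def parity_inversions(perm):
    """Sign of a permutation from its inversion count (O(N^2), no loops)."""
    n = perm.shape[0]                        # static, so triu_indices works under vmap
    i, j = jnp.triu_indices(n, k=1)          # all index pairs with i < j
    inversions = jnp.sum(perm[i] > perm[j])  # pairs that are out of order
    return jnp.where(inversions % 2 == 0, 1, -1)
```

It would be used the same way as `compute_parity`, for example `jax.vmap(jax.vmap(parity_inversions))(jnp.argsort(n, axis=-1))`.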

Also note that there's no reason to compute the sorted_index with this implementation; you could call compute_parity directly on your array n instead.

jakevdp