I would like to use the package Oryx to invert an affine transformation written in JAX. The transformation maps x->y and depends on a set of adjustable parameters (which I call params). Specifically, the affine transformation is defined as:

import jax.numpy as jnp

def affine(params, x):
  return x * params['scale'] + params['shift']

params = dict(scale=1.5, shift=-1.)
x_in = jnp.array(3.)
y_out = affine(params, x_in)

I would like to invert affine with respect to its input x, as a function of params. Oryx provides oryx.core.inverse for inverting JAX functions. However, inverting a function that also takes parameters, like this:

import oryx

oryx.core.inverse(affine)(params, y_out)

doesn't work (AssertionError: length mismatch: [1, 3]), presumably because inverse doesn't know that I want to invert with respect to y_out but not params. What is the most elegant way to solve this for all possible values of params (i.e., with the inverse defined as a function of params) using oryx.core.inverse? I don't find the inverse docs very illuminating.

Update: jakevdp gave an excellent suggestion for a fixed set of params. I've clarified the question to indicate that I am asking how to define the inverse as a function of params.

Hylke

1 Answer

You can do this by closing over the static parameters, for example using partial:

from functools import partial
x = oryx.core.inverse(partial(affine, params))(y_out)

print(x)
# 3.0

Edit: if you want a single inverted function to work for multiple values of params, you will have to return params in the output. Otherwise there is no way to infer all three inputs (scale, shift, and x) from a single output value. It might look something like this:

def affine(params, x):
  return params, x * params['scale'] + params['shift']

params = dict(scale=1.5, shift=-1.)
x_in = jnp.array(3.)
_, y_out = affine(params, x_in)

_, x = oryx.core.inverse(affine)(params, y_out)
print(x)
# 3.0
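As a sanity check on whatever inverse you build, the affine map is simple enough to invert by hand. The sketch below does not use Oryx at all: affine_inverse is a hypothetical, hand-derived helper (not part of any library), and the extra params values are arbitrary illustrations. It only confirms what a params-dependent inverse should return when params change from one iteration to the next.

```python
import jax.numpy as jnp

def affine(params, x):
  # Forward map from the question: x -> y.
  return x * params['scale'] + params['shift']

def affine_inverse(params, y):
  # Hypothetical hand-derived inverse (not an Oryx function):
  # solve y = x * scale + shift for x.
  return (y - params['shift']) / params['scale']

x_in = jnp.array(3.)

# Arbitrary illustrative parameter sets, standing in for params
# that change on every loop iteration.
for scale, shift in [(1.5, -1.), (2., 0.5), (0.25, 4.)]:
  params = dict(scale=scale, shift=shift)
  y_out = affine(params, x_in)
  x_rec = affine_inverse(params, y_out)
  # The round trip should recover the original input.
  assert jnp.allclose(x_rec, x_in)
```

Comparing the output of the Oryx-derived inverse against a reference like this for a few params values is a cheap way to confirm it is inverting with respect to x only.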
jakevdp
  • If, for example, my `params` change on every iteration of a loop, would I then define a separate inverse function for every set of `params` (i.e., every iteration)? Or is there a way to define an inverse as a function of `params`? – Hylke Jun 23 '23 at 12:26
  • Sure, see the edited answer. – jakevdp Jun 23 '23 at 20:35