I would like to use the package Oryx to invert an affine transformation written in JAX. The transformation maps x->y and depends on a set of adjustable parameters (which I call params). Specifically, the affine transformation is defined as:
import jax.numpy as jnp
def affine(params, x):
    return x * params['scale'] + params['shift']
params = dict(scale=1.5, shift=-1.)
x_in = jnp.array(3.)
y_out = affine(params, x_in)
I would like to invert affine with respect to the input x, as a function of params. Oryx has a function oryx.core.inverse to invert JAX functions. However, inverting a function with parameters, like this:
import oryx
oryx.core.inverse(affine)(params, y_out)
doesn't work (AssertionError: length mismatch: [1, 3]), presumably because inverse doesn't know that I want to invert y_out but not params.
What is the most elegant way to solve this problem for all possible values (i.e., as a function) of params using oryx.core.inverse?
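To make the goal concrete, here is a minimal sketch of the behaviour being asked for, written by hand in plain JAX rather than with Oryx. The name affine_inverse_manual is hypothetical and not part of Oryx; it only illustrates the signature (params, y) -> x that the Oryx-based inverse should ideally reproduce.

```python
import jax.numpy as jnp


def affine(params, x):
    # Forward transformation from the question: x -> y.
    return x * params['scale'] + params['shift']


def affine_inverse_manual(params, y):
    # Hypothetical hand-written inverse (not an Oryx function):
    # solves y = x * scale + shift for x, for any params.
    return (y - params['shift']) / params['scale']


params = dict(scale=1.5, shift=-1.)
x_in = jnp.array(3.)
y_out = affine(params, x_in)

# Round trip: applying the inverse to y_out should recover x_in.
x_back = affine_inverse_manual(params, y_out)
```

The point of the question is to obtain an equivalent function from oryx.core.inverse without deriving the inverse manually.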
I find the inverse docs not very illuminating.
Update:
Jakevdp gave an excellent suggestion for a given set of params. I've clarified the question to indicate that I am wondering how to define the inverse as a function of params.