
Oryx: How to `inverse` a JAX function with additional parameters


I would like to use the Oryx package to invert an affine transformation written in JAX. The transformation maps x -> y and depends on a set of adjustable parameters, which I call params. Specifically, the affine transformation is defined as:

import jax.numpy as jnp

def affine(params, x):
  return x * params['scale'] + params['shift']

params = dict(scale=1.5, shift=-1.)
x_in = jnp.array(3.)
y_out = affine(params, x_in)
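For reference, the inverse with respect to x can be derived by hand, assuming params['scale'] is nonzero. Writing s = params['scale'] and b = params['shift'], this is the map oryx.core.inverse should recover:

```latex
y = s\,x + b \quad\Longrightarrow\quad x = \frac{y - b}{s}, \qquad s \neq 0.
% With s = 1.5, b = -1, x = 3:  y = 1.5 \cdot 3 - 1 = 3.5,  and  x = (3.5 + 1)/1.5 = 3.
```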

I would like to invert affine with respect to the input x, as a function of params. Oryx provides oryx.core.inverse for inverting JAX functions. However, inverting a function with parameters like this:

import oryx

oryx.core.inverse(affine)(params, y_out)

fails with AssertionError: length mismatch: [1, 3], presumably because inverse doesn't know that I want to invert with respect to y_out only, not params. What is the most elegant way to solve this for all possible values of params (i.e., to get the inverse as a function of params) using oryx.core.inverse? I don't find the inverse docs very illuminating.

Update: Jakevdp gave an excellent suggestion for a given set of params. I've clarified the question to indicate that I am wondering how to define the inverse as a function of params.


Solution

  • You can do this by closing over the static parameters, for example using partial:

    from functools import partial
    x = oryx.core.inverse(partial(affine, params))(y_out)
    
    print(x)
    # 3.0
    

    Edit: if you want a single inverted function to work for multiple values of params, you will have to return params in the output. Otherwise there's no way to infer all three inputs (scale, shift, and x) from a single output value. It might look something like this:

    def affine(params, x):
      return params, x * params['scale'] + params['shift']
    
    params = dict(scale=1.5, shift=-1.)
    x_in = jnp.array(3.)
    _, y_out = affine(params, x_in)
    
    _, x = oryx.core.inverse(affine)(params, y_out)
    print(x)
    # 3.0
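To see why returning params makes the map invertible for every value of params, here is a plain-Python sketch (no JAX or oryx) of the augmented map and a hand-written inverse. The helper name affine_inverse is hypothetical, and the sketch assumes params['scale'] is nonzero; it can serve as a reference to check the output of oryx.core.inverse against.

```python
def affine(params, x):
    # Forward map from the Edit: pass params through alongside y.
    return params, x * params['scale'] + params['shift']


def affine_inverse(params, y):
    # Hypothetical hand-derived inverse; assumes params['scale'] != 0.
    return params, (y - params['shift']) / params['scale']


params = dict(scale=1.5, shift=-1.)
_, y_out = affine(params, 3.)
_, x = affine_inverse(params, y_out)
print(x)  # 3.0

# The same inverse works for any params, since params are part of its input.
other = dict(scale=-2., shift=0.5)
_, y_other = affine(other, 3.)
print(affine_inverse(other, y_other)[1])  # 3.0
```

Because params passes through unchanged, the output (params, y) carries as many values as the input (params, x), so each output determines exactly one input. The original single-output affine lacks that property.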