Here is my code, which returns a value when given a probability between 0 and 1 (it is an inverse CDF function).
import jax.numpy as jnp
from jax import jit, vmap, lax
from jaxopt import Bisection

def find_y(M, a1, a2, a3):
    """
    Finds the value of y that corresponds to a given value of M(y), using the bisection method implemented with JAX.

    Parameters:
        M (float): The desired value of M(y).
        a1 (float): The value of coefficient a1.
        a2 (float): The value of coefficient a2.
        a3 (float): The value of coefficient a3.

    Returns:
        float: The value of y that corresponds to the given value of M(y).
    """
    # Define a function that returns the value of M(y) for a given y
    @jit
    def M_fn(y):
        eps = 1e-8  # A small epsilon to avoid taking the log of a negative number
        return a1 + a2 * jnp.log(y / (1 - y + eps)) + a3 * (y - 0.5) * jnp.log(y / (1 - y + eps))

    # Define a function that returns the difference between M(y) and M
    @jit
    def f(y):
        return M_fn(y) - M

    # Set the bracketing interval for the root-finding function
    interval = (1e-7, 1 - 1e-7)

    # Use the bisection function to find the root
    y = Bisection(f, *interval).run().params

    # Return the value of y
    return y

## test the algorithm
a1 = 16.
a2 = 3.396
a3 = 0.0
y = find_y(16, a1, a2, a3)
print(y)
I would like to pass an array for argument M instead of a scalar, but no matter what I try, I get an error (usually about some Boolean trace). Any ideas? Thanks!!
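You can do this with jax.vmap, as long as you set check_bracket=False in Bisection. With the default check_bracket=True, the bracket check cannot be traced, so run() fails under jit or vmap (see the jaxopt documentation for Bisection). The changed line is: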
y = Bisection(f, *interval, check_bracket=False).run().params
With that change to your function, you can pass a vector of values for M like this:
import jax
M = jnp.array([4, 8, 16, 32])
result = jax.vmap(find_y, in_axes=(0, None, None, None))(M, a1, a2, a3)
print(result)
[0.02837202 0.08661279 0.5 0.99108815]
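
As a sanity check on these numbers: with a3 = 0.0, as in the test values, and ignoring the small eps term, M(y) = a1 + a2 * log(y / (1 - y)) can be inverted by hand to y = 1 / (1 + exp(-(M - a1) / a2)). The sketch below evaluates that closed form with the standard library only. The helper name closed_form_y is hypothetical (it is not part of JAX or jaxopt), and the formula applies only when a3 is zero; for nonzero a3 the bisection approach is still needed.

```python
import math

def closed_form_y(M, a1, a2):
    """Invert M = a1 + a2 * log(y / (1 - y)) for y.

    Assumption: a3 == 0 and the eps term in M_fn is negligible.
    """
    return 1.0 / (1.0 + math.exp(-(M - a1) / a2))

# Same coefficients and M values as in the question and answer
a1, a2 = 16.0, 3.396
values = [closed_form_y(M, a1, a2) for M in (4, 8, 16, 32)]
print(values)
```

The printed values should agree with the vmap result above to within the bisection tolerance.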