Tags: python, numpy, error-handling, jax

Error message in Python with differentiation


I am computing these derivatives with the Monte Carlo approach for a generic call option. I am interested in the mixed derivative with respect to both S and sigma. When I compute it with algorithmic differentiation, I get the error shown at the end of this post. What could be a possible solution? To explain the code, here is the formula used to compute "X" in the code below:

[Image: formula used to compute X]
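Since the image itself is not reproduced here, this is the formula as the code below actually implements it (read off the code, not the original image): σ² is the N×1 vector of variances (`sigma2`), e the 1×j row of ones (`e`), C the diagonal volatility matrix (`C`) and U the N×j matrix of standard normal draws (`U`).

$$
X = -\tfrac{1}{2}\,\sigma^{2} e + \sqrt{T}\, C\, U,
\qquad C = \operatorname{diag}(\sigma_1,\dots,\sigma_N),
\qquad S_T = S_0 \odot e^{X}
$$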

from jax import jit, grad, vmap
import jax.numpy as jnp
import numpy as np  # needed for np.random.RandomState below
from jax import random
Underlying_asset = jnp.linspace(1.1,1.4,100)
volatilities = jnp.linspace(0.5,0.6,100)
def second_derivative_mc(S,vol):
    N = 100
    j,T,q,r,k = 10000,1.,0,0,1.
    S0 = jnp.array([S]).T #(Nx1) vector underlying asset
    C = jnp.identity(N)*vol    #matrix of volatilities with 0 outside diagonal 
    e = jnp.array([jnp.full(j,1.)])#(1xj) vector of "1"
    Rand = np.random.RandomState()
    Rand.seed(10)
    U= Rand.normal(0,1,(N,j)) #Random number for Brownian Motion
    sigma2 = jnp.array([vol**2]).T #Vector of variance Nx1

    first = jnp.dot(sigma2,e) #First part equation
    second = jnp.dot(C,U)     #Second part equation

    X = -0.5*first+jnp.sqrt(T)*second

    St = jnp.exp(X)*S0

    P = jnp.maximum(St-k,0)
    payoff = jnp.average(P, axis=-1)*jnp.exp(-q*T)
    return payoff 


greek = vmap(grad(grad(second_derivative_mc, argnums=1), argnums=0))(Underlying_asset,volatilities)

This is the error message:

> UnfilteredStackTrace                      Traceback (most recent call last)
> <ipython-input-78-0cc1da97ae0c> in <module>()
>      25 
> ---> 26 greek = vmap(grad(grad(second_derivative_mc, argnums=1), argnums=0))(Underlying_asset,volatilities)
> 
> UnfilteredStackTrace: TypeError: Gradient only defined for scalar-output functions. Output had shape: (100,).
> 
> The stack trace below excludes JAX-internal frames. The preceding is the original exception that occurred, unmodified.
> 
> The above exception was the direct cause of the following exception:
> 
> TypeError                                 Traceback (most recent call last)
> /usr/local/lib/python3.7/dist-packages/jax/_src/api.py in _check_scalar(x)
>     894     if isinstance(aval, ShapedArray):
>     895       if aval.shape != ():
> --> 896         raise TypeError(msg(f"had shape: {aval.shape}"))
>     897     else:
>     898       raise TypeError(msg(f"had abstract value {aval}"))
> 
> TypeError: Gradient only defined for scalar-output functions. Output had shape: (100,).
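The error does not depend on the Monte Carlo code. A minimal sketch with a hypothetical three-element function (`vector_valued` is made up for illustration) raises the same TypeError, and reducing the output to a scalar (here by summing) is one way to make `jax.grad` accept it:

```python
import jax
import jax.numpy as jnp

def vector_valued(x):
    # Hypothetical example: the output has shape (3,), not (), so grad rejects it.
    return x * jnp.arange(3.0)

caught = None
try:
    jax.grad(vector_valued)(2.0)
except TypeError as err:
    caught = err
    print(err)  # the same "Gradient only defined for scalar-output functions" error

# Summing the outputs gives a scalar-valued function, which grad can handle.
total_grad = jax.grad(lambda x: vector_valued(x).sum())(2.0)
print(total_grad)  # 3.0
```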

Solution

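  • As the error message indicates, gradients can only be computed for functions that return a scalar. Your function returns a vector: even when S and vol are scalars (as they are inside vmap), U has shape (N, j) = (100, 10000), so St has shape (100, 10000) and averaging over the last axis leaves 100 values: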

    print(len(second_derivative_mc(1.1, 0.5)))
    # 100
    

    For vector-valued functions you can compute the Jacobian (the generalization of the gradient to functions with vector outputs). Is this what you had in mind?

    from jax import jacobian
    greek = vmap(jacobian(jacobian(second_derivative_mc, argnums=1), argnums=0))(Underlying_asset,volatilities)
    

    Also, this is not what you asked about, but the function above will probably not work as you intend even once the issue in the question is fixed. NumPy RandomState objects are stateful, and thus will generally not work correctly with JAX transformations like grad, jit and vmap, which require side-effect-free code (see Stateful Computations In JAX). You might try using jax.random instead; see JAX: Random Numbers for more information.
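Combining both suggestions, here is a minimal sketch under stated assumptions: the pricing function is rewritten to handle one (S, vol) pair and return a scalar, so `grad` applies directly and `vmap` maps over the pairs; the N×N matrix C is dropped (with one asset per call it reduces to `vol`); the normals come from `jax.random` with a fixed key (10, mirroring `seed(10)`); and the drift uses the standard -½σ²T, which matches the question's code for T = 1. The name `mc_call_price` and the constants are hypothetical, chosen for illustration.

```python
import jax
import jax.numpy as jnp

# Assumed parameters, taken from the question: T = 1, q = 0, strike k = 1.
T, Q, K = 1.0, 0.0, 1.0
N_PATHS = 10_000

# One fixed key: every call reuses the same standard normals (common random
# numbers), replacing np.random.RandomState(...).seed(10) from the question.
KEY = jax.random.PRNGKey(10)
Z = jax.random.normal(KEY, (N_PATHS,))

def mc_call_price(S, vol):
    """Monte Carlo call price for ONE (S, vol) pair; returns a scalar."""
    X = -0.5 * vol**2 * T + jnp.sqrt(T) * vol * Z  # log-returns, shape (N_PATHS,)
    St = S * jnp.exp(X)                             # terminal prices
    payoff = jnp.maximum(St - K, 0.0)
    return jnp.mean(payoff) * jnp.exp(-Q * T)       # scalar, so grad is allowed

# Inner grad: d/dvol; outer grad: d/dS. vmap maps over the pairs (S_i, vol_i).
cross = jax.vmap(jax.grad(jax.grad(mc_call_price, argnums=1), argnums=0))

Underlying_asset = jnp.linspace(1.1, 1.4, 100)
volatilities = jnp.linspace(0.5, 0.6, 100)
greek = cross(Underlying_asset, volatilities)
print(greek.shape)  # (100,)
```

Because each call now returns a scalar, the result holds one mixed derivative per (S_i, vol_i) pair. One caveat: differentiating twice through jnp.maximum treats the kink at the strike as having zero curvature, so for a call this pathwise second derivative misses the term coming from the kink and is a biased estimate of the mixed derivative; smoothing the payoff is a common remedy.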