I am computing these derivatives using the Monte Carlo approach for a generic call option. I am interested in the mixed second derivative (with respect to both S and sigma). When I compute it with algorithmic differentiation, I get the error shown at the end of this post. What could be a possible solution? To explain the code a little, I am attaching the formula used to compute `X` in the code below:
```python
import numpy as np  # needed for np.random.RandomState below
from jax import jit, grad, vmap
import jax.numpy as jnp
from jax import random

Underlying_asset = jnp.linspace(1.1, 1.4, 100)
volatilities = jnp.linspace(0.5, 0.6, 100)

def second_derivative_mc(S, vol):
    N = 100
    j, T, q, r, k = 10000, 1., 0, 0, 1.
    S0 = jnp.array([S]).T  # (Nx1) vector underlying asset
    C = jnp.identity(N) * vol  # matrix of volatilities with 0 outside diagonal
    e = jnp.array([jnp.full(j, 1.)])  # (1xj) vector of "1"
    Rand = np.random.RandomState()
    Rand.seed(10)
    U = Rand.normal(0, 1, (N, j))  # Random number for Brownian Motion
    sigma2 = jnp.array([vol**2]).T  # Vector of variance Nx1
    first = jnp.dot(sigma2, e)  # First part equation
    second = jnp.dot(C, U)  # Second part equation
    X = -0.5 * first + jnp.sqrt(T) * second
    St = jnp.exp(X) * S0
    P = jnp.maximum(St - k, 0)
    payoff = jnp.average(P, axis=-1) * jnp.exp(-q * T)
    return payoff  # for scalar S and vol (as under vmap) this has shape (N,) = (100,)

greek = vmap(grad(grad(second_derivative_mc, argnums=1), argnums=0))(Underlying_asset, volatilities)
```
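For reference, here is a sketch of the model the code appears to implement. It is reconstructed from the code itself (not from the attached formula) and assumes the driftless lognormal setup implied by `r = q = 0`. The code omits the factor `T` in the first term, which is harmless here because `T = 1`:

$$
X = -\tfrac{1}{2}\sigma^{2}T + \sigma\sqrt{T}\,Z, \qquad Z \sim \mathcal{N}(0, 1),
$$

$$
S_T = S_0\,e^{X}, \qquad
\hat{C}(S_0, \sigma) = e^{-qT}\,\frac{1}{j}\sum_{i=1}^{j} \max\!\left(S_T^{(i)} - k,\, 0\right),
$$

and the quantity sought is the mixed derivative $\partial^{2}\hat{C} / \partial S_0\,\partial\sigma$.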
This is the error message:
```
UnfilteredStackTrace                      Traceback (most recent call last)
<ipython-input-78-0cc1da97ae0c> in <module>()
     25
---> 26 greek = vmap(grad(grad(second_derivative_mc, argnums=1), argnums=0))(Underlying_asset,volatilities)

18 frames
UnfilteredStackTrace: TypeError: Gradient only defined for scalar-output functions. Output had shape: (100,).

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/jax/_src/api.py in _check_scalar(x)
    894     if isinstance(aval, ShapedArray):
    895       if aval.shape != ():
--> 896         raise TypeError(msg(f"had shape: {aval.shape}"))
    897     else:
    898       raise TypeError(msg(f"had abstract value {aval}"))

TypeError: Gradient only defined for scalar-output functions. Output had shape: (100,).
```
As the error message indicates, gradients can only be computed for functions that return a scalar. Your function returns a vector:
```python
print(len(second_derivative_mc(1.1, 0.5)))
# 100
```
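The same constraint applies to any vector-valued function. A minimal toy sketch (the function `vector_valued` is hypothetical, chosen only to mimic a vector-shaped output like the `(100,)` one here):

```python
import jax
import jax.numpy as jnp


def vector_valued(x):
    # Hypothetical toy function: returns shape (3,), just as
    # second_derivative_mc returns shape (100,) for scalar inputs.
    return x * jnp.arange(1.0, 4.0)


try:
    jax.grad(vector_valued)(2.0)  # grad needs a scalar output, so this raises TypeError
except TypeError as err:
    print("grad failed:", err)

# jacobian returns one derivative per output element instead.
print(jax.jacobian(vector_valued)(2.0))  # [1. 2. 3.]
```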
For vector-valued functions, you can compute the Jacobian (the multi-dimensional generalization of the gradient). Is this what you had in mind?
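```python
from jax import jacobian
greek = vmap(jacobian(jacobian(second_derivative_mc, argnums=1), argnums=0))(Underlying_asset, volatilities)
# greek.shape == (100, 100): one row per (S, vol) pair, one column per output element
```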
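If the goal is one option price per (S, vol) pair, another route is to make the function return a scalar, so that the original `grad(grad(...))` composition applies unchanged and `vmap` supplies the batch. The sketch below assumes a single asset per call, r = q = 0, and draws from `jax.random` instead of `RandomState`. The names `call_price` and `n_paths` are hypothetical:

```python
import jax
import jax.numpy as jnp


def call_price(S, vol, key, n_paths=10_000, T=1.0, k=1.0):
    """Hypothetical single-asset Monte Carlo call price (r = q = 0); returns a scalar."""
    z = jax.random.normal(key, (n_paths,))           # one standard normal per path
    X = -0.5 * vol**2 * T + vol * jnp.sqrt(T) * z    # log-return, same form as in the question
    St = S * jnp.exp(X)                              # terminal prices
    return jnp.mean(jnp.maximum(St - k, 0.0))        # scalar output, so grad applies


# d/dS of d/dvol for a single (S, vol) pair ...
d2_dS_dvol = jax.grad(jax.grad(call_price, argnums=1), argnums=0)
# ... batched over (S, vol) pairs; the key is shared (in_axes=None), like the fixed seed.
d2_dS_dvol_batch = jax.vmap(d2_dS_dvol, in_axes=(0, 0, None))

key = jax.random.PRNGKey(10)
Underlying_asset = jnp.linspace(1.1, 1.4, 100)
volatilities = jnp.linspace(0.5, 0.6, 100)
greek = d2_dS_dvol_batch(Underlying_asset, volatilities, key)
print(greek.shape)  # (100,)
```

Sharing one key across the batch reuses the same normal draws for every (S, vol) pair, which mirrors the fixed seed in the question. One caution: `jnp.maximum` has a zero second derivative almost everywhere, so nested AD through the payoff captures only the pathwise part. Gamma, for instance, comes out as exactly zero, and cross derivatives such as d²C/dS dσ can be biased for the same reason.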
Also, this is not what you asked about, but the function above will probably not work as you intend even once the shape issue is fixed. NumPy `RandomState` objects are stateful, and thus will generally not work correctly with JAX transforms like `grad`, `jit`, `vmap`, etc., which require side-effect-free code (see Stateful Computations In JAX). You might try using `jax.random` instead; see JAX: Random Numbers for more information.
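A minimal sketch of the explicit-key style that `jax.random` uses. The seed 10 only mirrors the question:

```python
import jax

# Explicit PRNG state: the key plays the role of RandomState().seed(10).
key = jax.random.PRNGKey(10)

a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))  # same key -> identical draws, no hidden state

# To get new numbers, split the key and use the fresh subkey.
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, (3,))

print(bool((a == b).all()), bool((a == c).all()))  # True False
```

Because the key is an ordinary argument, functions that draw from it stay side-effect free and can be passed through `grad`, `jit`, and `vmap`.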