Tags: python, scipy, jax

Can't calculate matrix exponential in python


I want to calculate the matrix exponential of a batch of matrices:

import numpy as np  # needed for np.pi and np.random below
import jax.numpy as jnp
from jax.scipy.linalg import expm
from functools import reduce

num_qubits=2
theta = jnp.asarray(np.pi*np.random.random((15,2,2,2,2,2,2,2,2)))

def pauli_matrix(num_qubits):
    _pauli_matrices = jnp.array(
    [[[1, 0], [0, 1]], [[0, 1], [1, 0]], [[0, -1j], [1j, 0]], [[1, 0], [0, -1]]]
    )
    return reduce(jnp.kron, (_pauli_matrices for _ in range(num_qubits)))[1:]

def SpecialUnitary(num_qubits,theta):
    assert theta.shape[0] == 15
    A = jnp.tensordot(theta, pauli_matrix(num_qubits), axes=[[0], [0]])
    print(f'{A.shape= }{pauli_matrix(num_qubits).shape=}{theta.shape=}')
    return expm(1j*A/2)

SpecialUnitary(num_qubits,theta)

Shapes:

A.shape= (2, 2, 2, 2, 2, 2, 2, 2, 4, 4)pauli_matrix(num_qubits).shape=(15, 4, 4)theta.shape=(15, 2, 2, 2, 2, 2, 2, 2, 2)

Error:

ValueError: expected A to be a square matrix

I'm stuck: the documentation says that expm operates on the last two axes, which must form a square matrix, and A satisfies that (its last two axes are 4 x 4).


Solution

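  • The error most likely comes from an older JAX release, in which expm accepted only a single 2-D square matrix and rejected any input with extra leading axes. Batched expm is supported in recent JAX versions, and this works in JAX v0.4.7 or newer: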

    import jax.numpy as jnp
    import jax.scipy.linalg
    
    X = jnp.arange(128.0).reshape(2, 2, 2, 4, 4)
    
    result = jax.scipy.linalg.expm(X)
    print(result.shape)
    # (2, 2, 2, 4, 4)
    

    If for some reason you must use an older JAX version, you can work around this by using jax.numpy.vectorize. For example:

    expm = jnp.vectorize(jax.scipy.linalg.expm, signature='(n,n)->(n,n)')
    
    result = expm(X)
    print(result.shape)
    # (2, 2, 2, 4, 4)
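
Applied to the code from the question, the workaround might look like the sketch below. It assumes the vectorized wrapper is acceptable on your JAX version; the names batched_expm and special_unitary, and the small (15, 2, 3) batch shape, are chosen here for illustration and are not part of the original post.

```python
from functools import reduce

import numpy as np
import jax.numpy as jnp
import jax.scipy.linalg


def pauli_matrix(num_qubits):
    # Stack of the four single-qubit Pauli matrices, shape (4, 2, 2).
    paulis = jnp.array(
        [[[1, 0], [0, 1]], [[0, 1], [1, 0]], [[0, -1j], [1j, 0]], [[1, 0], [0, -1]]]
    )
    # Kronecker products of the stack; [1:] drops the identity string.
    return reduce(jnp.kron, (paulis for _ in range(num_qubits)))[1:]


# Hypothetical helper: expm mapped over all leading (batch) axes.
batched_expm = jnp.vectorize(jax.scipy.linalg.expm, signature='(n,n)->(n,n)')


def special_unitary(num_qubits, theta):
    generators = pauli_matrix(num_qubits)  # (4**num_qubits - 1, d, d)
    assert theta.shape[0] == generators.shape[0]
    # Contract the coefficient axis; batch axes of theta end up in front.
    A = jnp.tensordot(theta, generators, axes=[[0], [0]])
    return batched_expm(1j * A / 2)


rng = np.random.default_rng(0)
theta = jnp.asarray(np.pi * rng.random((15, 2, 3)))  # illustrative batch shape
U = special_unitary(2, theta)
print(U.shape)
# (2, 3, 4, 4)
```

Because A is Hermitian, each matrix in the result should be unitary up to floating-point error, which gives a quick sanity check on the batched output.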