I want to compute a batch of matrix exponentials with expm:
import numpy as np
import jax.numpy as jnp
from jax.scipy.linalg import expm
from functools import reduce

num_qubits = 2
# 15 coefficients (one per generator), each with a (2,)*8 batch shape
theta = jnp.asarray(np.pi * np.random.random((15, 2, 2, 2, 2, 2, 2, 2, 2)))

def pauli_matrix(num_qubits):
    # Single-qubit Paulis I, X, Y, Z stacked into shape (4, 2, 2)
    _pauli_matrices = jnp.array(
        [[[1, 0], [0, 1]], [[0, 1], [1, 0]], [[0, -1j], [1j, 0]], [[1, 0], [0, -1]]]
    )
    # All 4**n tensor products minus the all-identity term: (4**n - 1, 2**n, 2**n)
    return reduce(jnp.kron, (_pauli_matrices for _ in range(num_qubits)))[1:]

def SpecialUnitary(num_qubits, theta):
    assert theta.shape[0] == 15
    # Contract the generator axis: A has shape (2,)*8 + (4, 4)
    A = jnp.tensordot(theta, pauli_matrix(num_qubits), axes=[[0], [0]])
    print(f'{A.shape= }{pauli_matrix(num_qubits).shape=}{theta.shape=}')
    return expm(1j * A / 2)

SpecialUnitary(num_qubits, theta)
Printed shapes:
A.shape= (2, 2, 2, 2, 2, 2, 2, 2, 4, 4)
pauli_matrix(num_qubits).shape=(15, 4, 4)
theta.shape=(15, 2, 2, 2, 2, 2, 2, 2, 2)
Error: ValueError: expected A to be a square matrix
I'm stuck, because the documentation says that expm is computed over the last two axes, which must be square, and that is the case here: A.shape ends in (4, 4).
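A quick way to confirm that the square-matrix requirement is met, and that only the leading batch axes are the problem, is to exponentiate a single 4x4 slice of A. The sketch below rebuilds A the same way as the code above (two qubits, same theta shape) and indexes away the eight batch axes; a plain 2D matrix should be accepted by expm whether or not the installed JAX version supports batching. The names PAULIS, basis and U0 are just local names introduced for this sketch.

```python
from functools import reduce

import jax.numpy as jnp
import numpy as np
from jax.scipy.linalg import expm

# Single-qubit Paulis I, X, Y, Z, as in the question
PAULIS = jnp.array(
    [[[1, 0], [0, 1]], [[0, 1], [1, 0]], [[0, -1j], [1j, 0]], [[1, 0], [0, -1]]]
)
# 15 non-identity two-qubit Pauli products, shape (15, 4, 4)
basis = reduce(jnp.kron, (PAULIS, PAULIS))[1:]

theta = jnp.asarray(np.pi * np.random.random((15, 2, 2, 2, 2, 2, 2, 2, 2)))
A = jnp.tensordot(theta, basis, axes=[[0], [0]])  # shape (2,)*8 + (4, 4)

# Index away the eight leading batch axes to get one plain (4, 4) matrix;
# expm on a single 2D matrix does not need batching support.
U0 = expm(1j * A[(0,) * 8] / 2)
print(U0.shape)  # (4, 4)
```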
Batched expm is supported in recent JAX versions, and you should find that this works fine in JAX v0.4.7 or newer:
import jax.numpy as jnp
import jax.scipy.linalg
X = jnp.arange(128.0).reshape(2, 2, 2, 4, 4)
result = jax.scipy.linalg.expm(X)
print(result.shape)
# (2, 2, 2, 4, 4)
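Applied to the code in the question, this means the original approach should run on JAX v0.4.7 or newer without changes to the expm call. The sketch below is the question's code with the debug print removed and the generator-count assertion generalized, plus two sanity checks that are an addition here, not part of the question: each resulting matrix should be unitary (the generators are Hermitian, so 1j*A/2 is anti-Hermitian) and have determinant 1 (the generators are traceless, and det(exp(M)) = exp(tr M)). The tolerances are a loose guess for float32.

```python
from functools import reduce

import jax.numpy as jnp
import numpy as np
from jax.scipy.linalg import expm  # batched over leading axes in JAX >= 0.4.7

PAULIS = jnp.array(
    [[[1, 0], [0, 1]], [[0, 1], [1, 0]], [[0, -1j], [1j, 0]], [[1, 0], [0, -1]]]
)

def pauli_matrix(num_qubits):
    # All non-identity n-qubit Pauli products: (4**n - 1, 2**n, 2**n)
    return reduce(jnp.kron, (PAULIS for _ in range(num_qubits)))[1:]

def SpecialUnitary(num_qubits, theta):
    basis = pauli_matrix(num_qubits)
    assert theta.shape[0] == basis.shape[0]  # one coefficient per generator
    A = jnp.tensordot(theta, basis, axes=[[0], [0]])  # (*batch, 2**n, 2**n)
    return expm(1j * A / 2)  # one exponential per trailing (2**n, 2**n) matrix

theta = jnp.asarray(np.pi * np.random.random((15, 2, 2, 2, 2, 2, 2, 2, 2)))
U = SpecialUnitary(2, theta)
print(U.shape)  # (2, 2, 2, 2, 2, 2, 2, 2, 4, 4)

# Sanity checks (not in the question): each U is unitary with determinant 1
UUh = U @ jnp.conj(jnp.swapaxes(U, -1, -2))
print(bool(jnp.allclose(UUh, jnp.eye(4), atol=1e-4)))
print(bool(jnp.allclose(jnp.linalg.det(U), 1.0, atol=1e-4)))
```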
If for some reason you must use an older JAX version, you can work around this by using jax.numpy.vectorize. For example:
# Wrap the 2D expm so it is mapped over any leading batch axes
expm = jnp.vectorize(jax.scipy.linalg.expm, signature='(n,n)->(n,n)')
result = expm(X)
print(result.shape)
# (2, 2, 2, 4, 4)
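For comparison, the same per-matrix mapping can be written out explicitly with jax.vmap: collapse all leading batch axes into one, map the single-matrix expm over it, then restore the batch shape. This is a sketch of an alternative to the vectorize workaround, not something taken from the JAX docs, and batched_expm is a hypothetical helper name. The input is scaled down from the example above so that the exponentials stay comfortably within float32 range.

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg

def batched_expm(X):
    # Hypothetical helper: X has shape (*batch, n, n)
    *batch, n, m = X.shape
    flat = X.reshape(-1, n, m)                   # (prod(batch), n, n)
    out = jax.vmap(jax.scipy.linalg.expm)(flat)  # expm on each (n, n) matrix
    return out.reshape(*batch, n, m)             # restore the batch axes

X = jnp.arange(128.0).reshape(2, 2, 2, 4, 4) / 128.0
result = batched_expm(X)
print(result.shape)
# (2, 2, 2, 4, 4)
```

Because vmap only ever hands expm a single 2D matrix, this pattern does not rely on expm itself accepting batched input.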