Multiplying chains of matrices in JAX


Suppose I have a vector of parameters p that parameterizes a set of matrices A_1(p), A_2(p), ..., A_N(p). For a list of indices q of length M, I have to compute A_{q_M} * ... * A_{q_2} * A_{q_1} * v, and I have to do this for several different qs. Each q has a different length but, crucially, never changes. What changes, and what I wish to take gradients with respect to, is p.

I'm trying to figure out how to convert this to performant JAX. One way to do it is to build a large matrix Q that holds one q per row, padded with entries standing for the identity matrix so that every multiplication chain has the same length, and then scan over a function that uses switch to choose between N different functions, each doing a matrix-vector multiplication by A_n(p).

However, I don't particularly like the idea of this padding. The distribution of lengths of the qs has a very long tail, so Q will be dominated by padding. Since Q is fixed, is there potentially a smarter way to do this?
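For reference, the padded-scan idea described above could be sketched roughly as follows. This is only an illustration under stated assumptions: `matrices`, `chains`, `PAD`, and `apply_chains` are hypothetical names, the A_n(p) are stand-in 2x2 rotations, and the sketch indexes a stacked array (with an identity appended as the padding entry) instead of calling switch.

```python
import jax
import jax.numpy as jnp
import numpy as np

def matrices(p):
    """Hypothetical stand-in for A_1(p), ..., A_N(p): 2x2 rotations by p[n]."""
    c, s = jnp.cos(p), jnp.sin(p)
    rows0 = jnp.stack([c, -s], axis=-1)        # (N, 2): first row of each matrix
    rows1 = jnp.stack([s, c], axis=-1)         # (N, 2): second row of each matrix
    return jnp.stack([rows0, rows1], axis=-2)  # (N, 2, 2)

# Fixed index chains (0-based); the first index in each chain is applied first.
chains = [[0, 1], [1], [2, 0, 1, 1]]
N = 3
PAD = N  # index of the identity matrix appended after the N real matrices

# Build the padded index matrix Q once, outside of any traced code.
Q = np.full((len(chains), max(map(len, chains))), PAD, dtype=np.int32)
for i, q in enumerate(chains):
    Q[i, :len(q)] = q

@jax.jit
def apply_chains(p, v):
    A = jnp.concatenate([matrices(p), jnp.eye(2)[None]])  # (N + 1, 2, 2)

    def step(vs, idx):
        # vs: (num_chains, 2) current vectors; idx: (num_chains,) matrix indices
        return jnp.einsum("cij,cj->ci", A[idx], vs), None

    vs0 = jnp.broadcast_to(v, (Q.shape[0], v.shape[0]))
    out, _ = jax.lax.scan(step, vs0, jnp.asarray(Q.T))
    return out  # (num_chains, 2): one result vector per chain

p = jnp.array([0.1, 0.2, 0.3])
v = jnp.array([1.0, 0.0])
result = apply_chains(p, v)
# Gradient of one output entry with respect to the parameters p.
grad_p = jax.grad(lambda p: apply_chains(p, v)[0, 0])(p)
```

The scan always runs for the length of the longest chain, so with a long-tailed length distribution most steps multiply by the identity, which is exactly the cost raised above.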

EDIT: Here's a minimal example (edit 2: now functional):

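import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm  # assumed source of expm

# complex128 needs 64-bit mode; without it JAX truncates to complex64
jax.config.update("jax_enable_x64", True)

sigma0 = jnp.eye(2)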
sigmax = jnp.array([[0, 1], [1, 0]])
sigmay = jnp.array([[0, -1j], [1j, 0]])
sigmaz = jnp.array([[1, 0], [0, -1]])
sigma = jnp.array([sigmax, sigmay, sigmaz])

def gates_func(params):
    theta = params["theta"]
    epsilon = params["epsilon"]

    n = jnp.array([jnp.cos(theta), 0, jnp.sin(theta)])
    omega = jnp.pi / 2 * (1 + epsilon)
    X90 = expm(-1j * omega * jnp.einsum("i,ijk->jk", n, sigma) / 2)

    return {
        "Z90": expm(-1j * jnp.pi / 2 * sigmaz / 2),
        "X90": X90
    }

def multiply_out(params):
    gate_lists = [["X90", "X90"], ["X90","Z90"], ["Z90", "X90"], ["X90","Z90","X90"]]

    gates = gates_func(params)
    out = jnp.zeros(len(gate_lists)) 
    
    for i, gate_list in enumerate(gate_lists):
        init = jnp.array([1.0,0.0], dtype=jnp.complex128)
        for g in gate_list:
            init = gates[g] @ init
        out = out.at[i].set(jnp.abs(init[0]))

    return out

params = dict(theta=-0.0, epsilon=0.001)
multiply_out(params)

Solution

  • The main issue here is that JAX does not support string inputs. However, you can use NumPy to manipulate string arrays and turn them into integer categorical arrays, which can then be used with jax.jit and jax.vmap. The solution might look something like this:

    import jax
    import jax.numpy as jnp
    import numpy as np
    
    def gates_func_int(params, gate_list_vals):
      g = gates_func(params)
      identity = jnp.eye(*list(g.values())[0].shape)
      return jnp.stack([g.get(val, identity) for val in gate_list_vals])
    
    @jax.jit
    def multiply_out_2(params):
      # compile-time pre-processing
      gate_lists = [["X90", "X90"], ["X90","Z90"], ["Z90", "X90"], ["X90","Z90","X90"]]
      max_size = max(map(len, gate_lists))
      gate_array = np.array([gates + [''] * (max_size - len(gates))
                            for gates in gate_lists])
      gate_list_vals, gate_list_ints = np.unique(gate_array, return_inverse=True)
      gate_list_ints = gate_list_ints.reshape(gate_array.shape)
    
      # runtime computation
      gates = gates_func_int(params, gate_list_vals)[gate_list_ints]
      initial = jnp.array([[1.0],[0.0]], dtype=jnp.complex128)
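      # multi_dot([a, b, v]) computes a @ b @ v, so reverse g to make the first
      # gate in each list act on the vector first, matching the loop in the question
      return jax.vmap(lambda g: jnp.abs(jnp.linalg.multi_dot([*g[::-1], initial]))[0])(gates).ravel()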
    
    multiply_out_2(params)
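Since the question is ultimately about gradients with respect to the parameters, and the gate lists are fixed Python data, one padding-free option may be to jit the original Python loops directly: the strings are only used as dict keys while tracing, so they never become JAX inputs. The sketch below is an assumption-laden illustration rather than part of the answer above: `GATE_LISTS` and `multiply_out_unrolled` are hypothetical names, the rotation axis is written out without the einsum (the y component is zero), and complex64 is used so that 64-bit mode is not required.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm

sigmax = jnp.array([[0, 1], [1, 0]], dtype=jnp.complex64)
sigmaz = jnp.array([[1, 0], [0, -1]], dtype=jnp.complex64)

def gates_func(params):
    theta, epsilon = params["theta"], params["epsilon"]
    omega = jnp.pi / 2 * (1 + epsilon)
    # Rotation axis n = (cos(theta), 0, sin(theta)), contracted with the Paulis.
    axis = jnp.cos(theta) * sigmax + jnp.sin(theta) * sigmaz
    return {
        "Z90": expm(-1j * jnp.pi / 4 * sigmaz),
        "X90": expm(-1j * omega * axis / 2),
    }

# Static description of the chains; never passed through JAX tracing.
GATE_LISTS = (("X90", "X90"), ("X90", "Z90"), ("Z90", "X90"), ("X90", "Z90", "X90"))

@jax.jit
def multiply_out_unrolled(params):
    gates = gates_func(params)
    v0 = jnp.array([1.0, 0.0], dtype=jnp.complex64)
    outs = []
    for gate_list in GATE_LISTS:   # Python loops run at trace time only
        v = v0
        for name in gate_list:     # first gate in the list is applied first
            v = gates[name] @ v
        outs.append(jnp.abs(v[0]))
    return jnp.stack(outs)

params = dict(theta=jnp.asarray(0.0), epsilon=jnp.asarray(0.001))
values = multiply_out_unrolled(params)
# Forward-mode Jacobian: a dict with one (num_chains,) array per parameter.
jac = jax.jacfwd(multiply_out_unrolled)(params)
```

The trade-off is that every matrix-vector product is unrolled into the compiled program, so compile time and program size grow with the total number of gates across all chains. For a few short chains this is likely fine; for many long chains the padded, vectorized version may compile much faster.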