Tags: python, optimization, jax

Mapping a vector to a matrix in JAX


I want to use JAX to optimize the elements of a vector, where the loss function depends on a matrix built from those elements. Specifically, element (n, m) of the matrix corresponds to element n + m of the vector. I have tried

import jax.numpy as jnp

def get_F_matrix(vector):
    N = vector.shape[0]
    F = jnp.zeros((N//2,N//2))
    for i in range(N//2):
        for j in range(N//2):
            # Each .at[...].set() call returns a new array
            F = F.at[i,j].set(vector[i + j])
    return F

but this takes a very long time when the vector is large. Does anyone know of a way to map the vector to the matrix directly and efficiently?


Solution

  • Seems like you're looking for a moving window function. Code from this GitHub comment:

    from functools import partial
    import jax
    import jax.numpy as jnp
    from jax import jit, vmap
    
    @partial(jit, static_argnums=(1,))
    def moving_window(a, size: int):
        starts = jnp.arange(len(a) - size + 1)
        return vmap(lambda start: jax.lax.dynamic_slice(a, (start,), (size,)))(starts)
    

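    This stays entirely within JAX and is compiled with jit, while your code runs Python loops that dispatch a separate .at[...].set() update for every matrix element, so the JAX version should be much faster.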

    EDIT: Indeed, it is faster:

    In [7]: vector = jax.numpy.arange(0, 100)
    
    In [8]: %timeit get_F_matrix(vector)
    9.14 s ± 47.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    
    In [9]: %timeit moving_window(vector, 100//2)[:-1]
    671 µs ± 15.8 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
    

    Example:

    >>> vector = jax.numpy.arange(0, 10)
    >>> get_F_matrix(vector)
    DeviceArray([[0., 1., 2., 3., 4.],
                 [1., 2., 3., 4., 5.],
                 [2., 3., 4., 5., 6.],
                 [3., 4., 5., 6., 7.],
                 [4., 5., 6., 7., 8.]], dtype=float32)
    >>> moving_window(vector, 10//2)[:-1] # chop off the last row
    DeviceArray([[0, 1, 2, 3, 4],
                 [1, 2, 3, 4, 5],
                 [2, 3, 4, 5, 6],
                 [3, 4, 5, 6, 7],
                 [4, 5, 6, 7, 8]], dtype=int32)
    

    Your function returns a float32 array, while moving_window preserves the dtype of the original vector. If you need a specific dtype, you can convert the result with .astype(...).
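
As a minimal sketch of how the pieces above could be combined, the block below wraps moving_window in a helper that returns the same square float32 matrix as get_F_matrix. The names get_F_matrix_fast, get_F_matrix_indexed, and loss are hypothetical and not part of the original answer. The sketch slices with [:N//2] instead of [:-1], on the assumption that the vector length may be odd, in which case [:-1] would leave one row too many. An index-based variant is included only as a cross-check, and a toy loss shows that gradients flow back to the vector.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=(1,))
def moving_window(a, size: int):
    # Same helper as in the answer: all contiguous windows of length `size`.
    starts = jnp.arange(len(a) - size + 1)
    return jax.vmap(lambda start: jax.lax.dynamic_slice(a, (start,), (size,)))(starts)


def get_F_matrix_fast(vector, dtype=jnp.float32):
    # Hypothetical wrapper: F[i, j] = vector[i + j] for i, j < N // 2.
    half = vector.shape[0] // 2
    # Keep only the first `half` windows so the result is square,
    # then cast to match the dtype of the loop-based version.
    return moving_window(vector, half)[:half].astype(dtype)


def get_F_matrix_indexed(vector, dtype=jnp.float32):
    # Alternative sketch: build the index grid i + j by broadcasting
    # and gather all entries in one indexing operation.
    half = vector.shape[0] // 2
    idx = jnp.arange(half)
    return vector[idx[:, None] + idx[None, :]].astype(dtype)


def loss(v):
    # Toy loss (an assumption for illustration): sum of squared matrix entries.
    return jnp.sum(get_F_matrix_fast(v) ** 2)


vector = jnp.arange(10)
F = get_F_matrix_fast(vector)

# Gradients need a floating-point input vector.
grad = jax.grad(loss)(jnp.ones(10))
```

Both helpers should produce the same matrix, so either can be dropped into the loss function being optimized. The gradient entry for each vector element scales with the number of matrix positions in which that element appears.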