I want to use JAX to optimize the elements of a vector with a loss function that depends on a matrix built from those elements. Specifically, element (n, m) of the matrix corresponds to element n + m of the vector. I have tried
import jax.numpy as jnp

def get_F_matrix(vector):
    N = vector.shape[0]
    F = jnp.zeros((N // 2, N // 2))
    # Fill the matrix one element at a time: F[i, j] = vector[i + j]
    for i in range(N // 2):
        for j in range(N // 2):
            F = F.at[i, j].set(vector[i + j])
    return F
but this takes a very long time when the vector is of significant size. Does anyone know of a way to map the vector directly to the matrix efficiently?
Seems like you're looking for a moving window function. Code from this GitHub comment:
from functools import partial

import jax
import jax.numpy as jnp
from jax import jit, vmap

@partial(jit, static_argnums=(1,))
def moving_window(a, size: int):
    # One window start per output row; vmap vectorizes the slice over all starts.
    starts = jnp.arange(len(a) - size + 1)
    return vmap(lambda start: jax.lax.dynamic_slice(a, (start,), (size,)))(starts)
This uses pure JAX, while your code runs pure Python loops, so the JAX version should be much faster.
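As an aside, since F[i, j] = vector[i + j], the same matrix can also be built with a single broadcasted gather instead of a vmapped slice. A minimal sketch (the function name get_F_matrix_indexed is my own, not from the question):

```python
import jax.numpy as jnp

def get_F_matrix_indexed(vector):
    n = vector.shape[0] // 2
    idx = jnp.arange(n)
    # idx[:, None] + idx[None, :] broadcasts to an (n, n) index matrix
    # whose (i, j) entry is i + j; indexing the vector with it gathers
    # the whole matrix in one vectorized operation.
    return vector[idx[:, None] + idx[None, :]]
```

Either approach avoids the Python-level double loop; which one traces faster is worth benchmarking on your actual sizes.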
EDIT: Indeed, it is faster:
In [7]: vector = jax.numpy.arange(0, 100)
In [8]: %timeit get_F_matrix(vector)
9.14 s ± 47.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [9]: %timeit moving_window(vector, 100//2)[:-1]
671 µs ± 15.8 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
Example:
>>> vector = jax.numpy.arange(0, 10)
>>> get_F_matrix(vector)
DeviceArray([[0., 1., 2., 3., 4.],
             [1., 2., 3., 4., 5.],
             [2., 3., 4., 5., 6.],
             [3., 4., 5., 6., 7.],
             [4., 5., 6., 7., 8.]], dtype=float32)
>>> moving_window(vector, 10//2)[:-1] # chop off the last row
DeviceArray([[0, 1, 2, 3, 4],
             [1, 2, 3, 4, 5],
             [2, 3, 4, 5, 6],
             [3, 4, 5, 6, 7],
             [4, 5, 6, 7, 8]], dtype=int32)
Your function returns an array of float32, while moving_window preserves the dtype of the original vector; it is simple to convert the resulting array to the type you need.