I am writing a simple classifier in an attempt to come to grips with JAX. I want to normalize the data (it's the Iris dataset from sklearn), and my little function works, but from what I've read in the JAX "The Sharp Bits" documentation, I should avoid using lambda functions and iterating over vectors. I am not versed in functional programming, and I am curious whether there's a better, more JAX-idiomatic way to do this. Here is my code so far:
import time

import jax.numpy as jnp
from jax import jit, vmap
# lots of imports ...
iris = load_it('data', 'iris.pkl')
def normalize(data):
    return jnp.apply_along_axis(lambda x: x / jnp.linalg.norm(x), 1, data)
# TODO: use a functional style, maybe use partial
# and get rid of the lambda ...
tic = time.perf_counter()
iris_data_normal = normalize(iris.data)
toc = time.perf_counter()
print(f"It took jax {toc - tic:0.4f} seconds.")
When I run this I get: It took jax 0.0677 seconds.
Any guidance is most appreciated!
Your current approach looks fine to me: it is pure (the function has no side effects), and in JAX apply_along_axis is implemented in terms of vmap, so there's no problem in terms of computational efficiency.
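To make that concrete, here is a minimal sketch of roughly what that desugaring looks like: vmap maps the per-row function over the leading axis, which for a 2D array matches apply_along_axis with axis=1. (The name normalize_vmap is just for illustration.)

import jax.numpy as jnp
from jax import vmap

def normalize_vmap(data):
    # Map the per-row normalization over axis 0 (the rows);
    # for a 2D array this is equivalent to apply_along_axis with axis=1.
    return vmap(lambda x: x / jnp.linalg.norm(x))(data)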
If you wanted to write a similar function using direct array operations, you could equivalently do something like this:
def normalize(data):
    return data / jnp.linalg.norm(data, axis=1, keepdims=True)
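One side note on your timing: JAX dispatches work asynchronously, so a perf_counter measurement can reflect dispatch rather than actual execution. If you want a more meaningful number, a sketch (assuming iris is loaded as in your snippet, and using the normalize above) would be to jit the function, warm it up once so compilation isn't counted, and call block_until_ready() before reading the clock:

from jax import jit

normalize_jit = jit(normalize)
# First call triggers compilation; run it once before timing.
_ = normalize_jit(iris.data).block_until_ready()

tic = time.perf_counter()
iris_data_normal = normalize_jit(iris.data).block_until_ready()
toc = time.perf_counter()
print(f"It took jax {toc - tic:0.4f} seconds.")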