Tags: lambda, iterator, closures, jax

What is a pure functional version of this JAX function?


I am writing a simple classifier in an attempt to come to grips with JAX. I want to normalize the data (the Iris dataset from sklearn), and my little function works, but from what I've read in the JAX - The Sharp Bits documentation, I should avoid using lambda functions and iterating over vectors. I am not well versed in functional programming, and I am curious whether there is a better, more JAX-idiomatic way to do this. Here is my code so far:

import time
import jax.numpy as jnp
from jax import jit, vmap
# lots of imports ...

iris = load_it('data', 'iris.pkl')

def normalize(data):
    # scale each row (axis 1) to unit L2 norm
    return jnp.apply_along_axis(lambda x: x / jnp.linalg.norm(x), 1, data)

# TODO: use a functional style, maybe use partial
# and get rid of the lambda ...

tic = time.perf_counter() 
iris_data_normal = normalize(iris.data) 
toc = time.perf_counter() 
print(f"It took jax {toc - tic:0.4f} seconds.")

When I run this I get: It took jax 0.0677 seconds. Any guidance is most appreciated!


Solution

  • Your current approach looks fine to me: it is pure (the function has no side effects), and in JAX apply_along_axis is implemented in terms of vmap, so there is no computational-efficiency concern either.

    If you wanted to write a similar function using direct array operations, you could equivalently do something like this:

    def normalize(data):
        # one L2 norm per row; keepdims=True keeps shape (n, 1) so broadcasting divides each row by its own norm
        return data / jnp.linalg.norm(data, axis=1, keepdims=True)
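
    For reference, since apply_along_axis is implemented on top of vmap, here is a minimal sketch of what the same computation looks like with vmap written out explicitly (the helper name normalize_row is just for illustration, not part of your code or the JAX API):

    import jax.numpy as jnp
    from jax import vmap

    def normalize_row(row):
        # scale a single 1-D row to unit L2 norm
        return row / jnp.linalg.norm(row)

    # vmap maps the per-row function over axis 0 (the rows) of a 2-D array
    normalize_vmapped = vmap(normalize_row)

    # quick sanity check against the broadcasting version above
    data = jnp.arange(12.0).reshape(3, 4) + 1.0
    assert jnp.allclose(normalize_vmapped(data),
                        data / jnp.linalg.norm(data, axis=1, keepdims=True))

    All three forms (apply_along_axis, explicit vmap, and the broadcasting one-liner) should give the same result; the broadcasting version is usually the most idiomatic for a simple row-wise norm.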