What is the correct way to define a vectorized (jax.vmap) function in a class?


I want to add a function that is vectorized with jax.vmap as a class method. However, I am not sure where to define this function within the class. My main goal is to avoid redefining the function each time I call the class method.

Here is a minimal example of a class that counts how often a value occurs in a jnp.array, with a non-vectorized and a vectorized version:

import jax.numpy as jnp
import jax

class ValueCounter:

    def __init__(self):  # for completeness, not used
        self.attribute_1 = None

    @staticmethod
    def _count_value_in_array(  # non-vectorized function
        array: jax.Array, value: float
    ) -> jax.Array:
        """Count how often a value occurs in an array"""
        return jnp.count_nonzero(array == value)

    # here comes the vectorized function
    def count_values_in_array(self, array: jax.Array, value_array: jax.Array) -> jax.Array:
        """Count how often each value in an array of values occurs in an array"""
        count_value_in_array_vec = jax.vmap(
            self._count_value_in_array, in_axes=(None, 0)
        )  # vectorized function is redefined each time this method is called
        return count_value_in_array_vec(array, value_array)

Example input and output:

value_counter = ValueCounter()
value_counter.count_values_in_array(jnp.array([0, 1, 2, 2, 1, 1]), jnp.array([0, 1, 2]))

Result (correct, as expected):

Array([1, 3, 2], dtype=int32)

The vectorized function count_value_in_array_vec is redefined each time count_values_in_array is called, which seems unnecessary to me. However, I am a bit stuck on how to avoid this. Does anyone know a more elegant way to integrate the vectorized function into the class?
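For comparison, here is a minimal sketch of one common workaround (an editorial assumption, not part of the original question or answer): build the vmapped wrapper once in __init__ and store it on the instance. The attribute name _count_values_vec is hypothetical. This avoids re-creating the wrapper on every call, at the cost of one wrapper per instance.

```python
import jax
import jax.numpy as jnp


class ValueCounter:
    def __init__(self):
        # Hypothetical workaround: create the vmapped wrapper once per instance,
        # instead of once per call of count_values_in_array.
        self._count_values_vec = jax.vmap(self._count_value_in_array, in_axes=(None, 0))

    @staticmethod
    def _count_value_in_array(array: jax.Array, value: float) -> jax.Array:
        """Count how often a single value occurs in an array."""
        return jnp.count_nonzero(array == value)

    def count_values_in_array(self, array: jax.Array, value_array: jax.Array) -> jax.Array:
        """Count how often each value in value_array occurs in array."""
        # Reuse the wrapper built in __init__; no new vmap call here.
        return self._count_values_vec(array, value_array)


value_counter = ValueCounter()
counts = value_counter.count_values_in_array(jnp.array([0, 1, 2, 2, 1, 1]), jnp.array([0, 1, 2]))
print(counts)
```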


Solution

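You can decorate the static method directly with jax.vmap, using functools.partial to pass in_axes. The vectorized function is then created only once, when the class body is executed. For example: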

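    from functools import partial

    import jax
    import jax.numpy as jnp


    class ValueCounter:

        @staticmethod  # must stay outermost so the vmapped function is not bound to self
        @partial(jax.vmap, in_axes=(None, 0))  # applied once, when the class is defined
        def _count_value_in_array(
            array: jax.Array, value: float
        ) -> jax.Array:
            """Count how often each value in `value` occurs in `array`"""
            return jnp.count_nonzero(array == value)

        def count_values_in_array(self, array: jax.Array, value_array: jax.Array) -> jax.Array:
            # _count_value_in_array is already vectorized, so just call it
            return self._count_value_in_array(array, value_array)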
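If the method is called repeatedly with arrays of the same shape and dtype, stacking jax.jit on top of the vmap decorator is an optional extension (not part of the answer above; treat it as a sketch) that also caches the compiled computation, since jit only recompiles for new input shapes or dtypes. The variable names counter and counts are illustrative.

```python
from functools import partial

import jax
import jax.numpy as jnp


class ValueCounter:
    @staticmethod  # outermost, so neither jit nor vmap sees self
    @jax.jit  # optional addition: compile the vectorized function once per input shape/dtype
    @partial(jax.vmap, in_axes=(None, 0))  # vectorize over value, broadcast array
    def _count_value_in_array(array: jax.Array, value: float) -> jax.Array:
        """Count how often a value occurs in an array."""
        return jnp.count_nonzero(array == value)

    def count_values_in_array(self, array: jax.Array, value_array: jax.Array) -> jax.Array:
        """Count how often each value in value_array occurs in array."""
        return self._count_value_in_array(array, value_array)


counter = ValueCounter()
counts = counter.count_values_in_array(jnp.array([0, 1, 2, 2, 1, 1]), jnp.array([0, 1, 2]))
print(counts)
```

The counts match the result shown in the question, [1, 3, 2].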