I want to add a function that is vectorized with jax.vmap as a class method. However, I am not sure where to define this function within the class. My main goal is to avoid redefining the function each time I call the class method.
Here is a minimal example of a class that counts how often a value occurs in a jnp.array, with a non-vectorized and a vectorized version:
import jax.numpy as jnp
import jax

class ValueCounter():
    def __init__(self):  # for completeness, not used
        self.attribute_1 = None

    @staticmethod
    def _count_value_in_array(  # non-vectorized function
        array: jnp.array, value: float
    ) -> jnp.array:
        """Count how often a value occurs in an array"""
        return jnp.count_nonzero(array == value)

    # here comes the vectorized function
    def count_values_in_array(self, array: jnp.array, value_array: jnp.array) -> jnp.array:
        """Count how often each value in an array of values occurs in an array"""
        count_value_in_array_vec = jax.vmap(
            self._count_value_in_array, in_axes=(None, 0)
        )  # vectorized function is defined again each time the function is called
        return count_value_in_array_vec(array, value_array)
Example input and output:
value_counter = ValueCounter()
value_counter.count_values_in_array(jnp.array([0, 1, 2, 2, 1, 1]), jnp.array([0, 1, 2]))
Result (correct, as expected):
Array([1, 3, 2], dtype=int32)
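To make explicit what in_axes=(None, 0) does here, the following is a small sketch (a standalone function rather than the class, for brevity) that checks the vmapped call against a plain Python loop over value_array. The first argument is passed unchanged to every call, and the second is mapped over its leading axis.

```python
import jax
import jax.numpy as jnp


def count_value_in_array(array, value):
    """Count how often a single value occurs in an array."""
    return jnp.count_nonzero(array == value)


array = jnp.array([0, 1, 2, 2, 1, 1])
value_array = jnp.array([0, 1, 2])

# in_axes=(None, 0): `array` is not mapped (the same array goes to every call),
# while `value_array` is mapped over axis 0, one scalar value per call.
vectorized = jax.vmap(count_value_in_array, in_axes=(None, 0))(array, value_array)

# The same computation written as an explicit Python loop, for comparison.
looped = jnp.stack([count_value_in_array(array, v) for v in value_array])

print(vectorized, looped)
```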
The vectorized function count_value_in_array_vec is redefined each time count_values_in_array is called, which seems unnecessary to me. However, I am a bit stuck on how to avoid this.
Does someone know how the vectorized function could be integrated into the class in a more elegant way?
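You can decorate the static method directly with jax.vmap, so that the vectorized function is created only once, when the class body is executed; for example: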
from functools import partial

# ...

    @staticmethod
    @partial(jax.vmap, in_axes=(None, 0))
    def _count_value_in_array(
        array: jnp.array, value: float
    ) -> jnp.array:
        """Count how often a value occurs in an array"""
        return jnp.count_nonzero(array == value)

    # ...
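For reference, here is a sketch of how the complete class might look with this approach. It assumes that count_values_in_array simply forwards to the already-vectorized static method. It also swaps the jnp.array annotations for jax.Array, which is the array type (jnp.array is a function).

```python
from functools import partial

import jax
import jax.numpy as jnp


class ValueCounter:
    def __init__(self):  # for completeness, not used
        self.attribute_1 = None

    @staticmethod
    @partial(jax.vmap, in_axes=(None, 0))  # applied once, at class-definition time
    def _count_value_in_array(array: jax.Array, value: float) -> jax.Array:
        """Count how often a value occurs in an array (mapped over `value`)."""
        return jnp.count_nonzero(array == value)

    def count_values_in_array(self, array: jax.Array, value_array: jax.Array) -> jax.Array:
        """Count how often each value in value_array occurs in array."""
        # No jax.vmap call here: the stored static method is already vectorized.
        return self._count_value_in_array(array, value_array)


value_counter = ValueCounter()
counts = value_counter.count_values_in_array(jnp.array([0, 1, 2, 2, 1, 1]), jnp.array([0, 1, 2]))
print(counts)
```

The decorator order matters: jax.vmap wraps the plain function first, and staticmethod wraps the result, so the method can still be called as self._count_value_in_array(...) without receiving self.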