Negative Sampling in JAX


I'm implementing a negative sampling algorithm in JAX. The idea is to sample negatives from a range while excluding a set of non-acceptable outputs from that range. My current solution is close to the following:

import jax.numpy as jnp
import jax
max_range = 5
n_samples = 2
true_cases = jnp.array(
    [
        [1,2],
        [1,4],
        [0,5]
    ]
)
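# I combine the true cases into a dictionary of the following form
# (empty entries need an integer dtype to be usable as indices):
non_acceptable_as_negatives = {
    0: jnp.array([5]),
    1: jnp.array([2,4]),
    2: jnp.array([], dtype=int),
    3: jnp.array([], dtype=int),
    4: jnp.array([], dtype=int),
    5: jnp.array([], dtype=int)
}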
negatives = []
key = jax.random.PRNGKey(42)
for i in true_cases[:,0]:
    key,use_key  = jax.random.split(key,2)
    p = jnp.ones((max_range+1,))
    p = p.at[non_acceptable_as_negatives[int(i)]].set(0)
    p = p / p.sum()
    negatives.append(
        jax.random.choice(use_key,
            jnp.arange(max_range+1),
            (1, n_samples),
            replace=False,
            p=p,
            )
    )

However, this seems

  • rather complicated, and
  • not very performant: in the original problem the true cases contain ~200_000 entries and max_range is ~50_000.

How can I improve this solution? And is there a more JAX-idiomatic way to store the arrays of varying size that I currently keep in the non_acceptable_as_negatives dict?


Solution

  • You'll generally achieve better performance in JAX (as in NumPy) if you can avoid loops and use vectorized operations instead. If I'm understanding your function correctly, I think the following does roughly the same thing, but using vmap.

    Since JAX does not support dictionary lookups based on traced values, I replaced your dict with a padded array, in which shorter rows are padded with fill_value (an index that lies outside the sampling range):

    import jax.numpy as jnp
    import jax
    max_range = 5
    n_samples = 2
    fill_value = max_range + 1
    
    true_cases = jnp.array([
      [1,2],
      [1,4],
      [0,5]
    ])
    
    non_acceptable_as_negatives = jnp.array([
        [5, fill_value],
        [2, 4],
    ])
    
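    @jax.vmap
    def func(key, true_case):
      # Start from uniform weights over 0..max_range
      p = jnp.ones(max_range + 1)
      # Use this row's own first column as the lookup key
      idx = true_case[0]
      # Keys beyond the padded array come back as fill_value, i.e. no exclusions
      replace = non_acceptable_as_negatives.at[idx].get(mode='fill', fill_value=fill_value)
      # Zero out the excluded values; out-of-range fill_value entries are dropped
      p = p.at[replace].set(0, mode='drop')
      return jax.random.choice(key, max_range + 1, (n_samples,), replace=False, p=p)
    
    
    key = jax.random.PRNGKey(42)
    keys = jax.random.split(key, len(true_cases))
    result = func(keys, true_cases)
    print(result)
    
    This prints an array of shape (3, 2), one row of sampled negatives per true case; the exact values depend on the key and the JAX version.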
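
The padded array in the answer is written out by hand, which does not scale to ~200_000 true cases. Below is a minimal sketch of one way to build it on the host with NumPy. The helper name `build_padded_exclusions` is hypothetical (it is not part of JAX). The sketch assumes `true_cases` is an integer array of shape (N, 2) whose first column is the lookup key in 0..max_range and whose second column is the value to exclude.

```python
import numpy as np

def build_padded_exclusions(true_cases, max_range):
    """Row i lists the values excluded for key i, padded with max_range + 1."""
    true_cases = np.asarray(true_cases)
    fill_value = max_range + 1
    keys, values = true_cases[:, 0], true_cases[:, 1]

    # Number of excluded values per key; the widest row sets the padding width
    counts = np.bincount(keys, minlength=max_range + 1)
    width = max(int(counts.max()), 1)
    padded = np.full((max_range + 1, width), fill_value, dtype=np.int32)

    # A stable sort groups the pairs by key and keeps their original order
    order = np.argsort(keys, kind="stable")
    sorted_keys = keys[order]

    # Column of each pair = its position within its key's group
    starts = np.cumsum(counts) - counts
    cols = np.arange(len(keys)) - starts[sorted_keys]
    padded[sorted_keys, cols] = values[order]
    return padded

true_cases = np.array([[1, 2], [1, 4], [0, 5]])
padded = build_padded_exclusions(true_cases, max_range=5)
print(padded)
# row 0 -> [5, 6], row 1 -> [2, 4], rows 2..5 -> [6, 6]
```

The result can be passed to the vmapped function after converting it with `jnp.asarray(padded)`. Because this version has one row per key, every lookup is in bounds. The trade-off is memory: the array has (max_range + 1) × width entries, where width is the largest number of exclusions for any single key.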