python, numpy, machine-learning, deep-learning, jax

jax.lax.select vs jax.numpy.where


I was taking a look at the dropout implementation in Flax:

def __call__(self, inputs, deterministic: Optional[bool] = None):
    """Applies a random dropout mask to the input.

    Args:
      inputs: the inputs that should be randomly masked.
      deterministic: if false the inputs are scaled by `1 / (1 - rate)` and
        masked, whereas if true, no mask is applied and the inputs are returned
        as is.

    Returns:
      The masked inputs reweighted to preserve mean.
    """
    deterministic = merge_param(
        'deterministic', self.deterministic, deterministic)

    if (self.rate == 0.) or deterministic:
      return inputs

    # Prevent gradient NaNs in 1.0 edge-case.
    if self.rate == 1.0:
      return jnp.zeros_like(inputs)

    keep_prob = 1. - self.rate
    rng = self.make_rng(self.rng_collection)
    broadcast_shape = list(inputs.shape)
    for dim in self.broadcast_dims:
      broadcast_shape[dim] = 1
    mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape)
    mask = jnp.broadcast_to(mask, inputs.shape)
    return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))

In particular, I'm interested in the last line, lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs)), and I'm wondering why lax.select is used here instead of:

return jnp.where(mask, inputs / keep_prob, 0)

or even more simply:

return mask * inputs / keep_prob

Solution

  • jnp.where is basically the same as lax.select, except that it is more flexible in its inputs: for example, it will broadcast inputs to a common shape or cast them to a common dtype, whereas lax.select requires stricter matching of its inputs:

    >>> import jax.numpy as jnp
    >>> from jax import lax
    >>> x = jnp.arange(3)
    
    # Implicit broadcasting
    >>> jnp.where(x < 2, x[:, None], 0)
    DeviceArray([[0, 0, 0],
                 [1, 1, 0],
                 [2, 2, 0]], dtype=int32)
    
    >>> lax.select(x < 2, x[:, None], 0)
    TypeError: select cases must have the same shapes, got [(), (3, 1)].
    
    # Implicit type promotion
    >>> jnp.where(x < 2, jnp.zeros(3), jnp.arange(3))
    DeviceArray([0., 0., 2.], dtype=float32)
    
    >>> lax.select(x < 2, jnp.zeros(3), jnp.arange(3))
    TypeError: lax.select requires arguments to have the same dtypes, got int32, float32. (Tip: jnp.where is a similar function that does automatic type promotion on inputs).
    

    Library code is one place where the stricter semantics can be useful: rather than smoothing over potential implementation bugs and returning an unexpected output, it will complain loudly. Performance-wise, though (especially once JIT-compiled), the two are essentially equivalent.
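    A quick way to convince yourself of that last point (a minimal sketch of my own, not from the Flax code) is to trace both versions with jax.make_jaxpr: once the inputs already agree in shape and dtype, jnp.where and lax.select should bottom out in the same underlying select primitive, so the compiled code ends up the same.

    import jax
    import jax.numpy as jnp
    from jax import lax

    x = jnp.arange(3.0)
    mask = x < 2

    def f_where(m, v):
        # Flexible front end: broadcasts and promotes arguments as needed.
        return jnp.where(m, v, 0.0)

    def f_select(m, v):
        # Strict primitive: arguments must already match in shape and dtype.
        return lax.select(m, v, jnp.zeros_like(v))

    # Both traces should show the same select primitive doing the work.
    print(jax.make_jaxpr(f_where)(mask, x))
    print(jax.make_jaxpr(f_select)(mask, x))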

    As for why the flax developers chose lax.select vs. multiplying by a mask, I can think of two reasons:

    1. Multiplying by a mask is subject to implicit type promotion semantics, and it takes a lot more thought to anticipate problematic outputs than with a simple select, which is specifically designed for the intended operation (see the sketch after this list).
    2. Using multiplication causes the compiler to treat this operation as a multiplication, which it is not. A select is a much narrower and more precise operation than a multiplication, and specifying operations precisely often allows the compiler to optimize the result to a greater extent.
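    To make the first point concrete, here is a small sketch of my own (using a hypothetical integer input, which the real dropout layer would not normally see): the mask-multiplication version silently promotes the result to float32, while the select-based version from the Flax code refuses to run, which is exactly the kind of loud failure you want from library code.

    import jax.numpy as jnp
    from jax import lax

    inputs = jnp.arange(4)                        # int32, for illustration only
    mask = jnp.array([True, False, True, True])
    keep_prob = 0.9

    # Masked multiplication silently promotes the integer input to float32.
    out = mask * inputs / keep_prob
    print(out.dtype)                              # float32

    # The select-based version complains loudly instead: inputs / keep_prob
    # is float32 while jnp.zeros_like(inputs) is int32.
    try:
        lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))
    except TypeError as err:
        print(err)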