Tags: python, reinforcement-learning, jax

How to use JAX vmap to efficiently calculate importance sampling estimate


I have code that computes the off-policy importance sampling estimate commonly used in reinforcement learning. You don't need to know what that is to answer the question, but it may help if you do. I have a 1D array of instances of a custom Episode class. Each Episode has four attributes, all of which are arrays of floats. My function loops over the episodes and, for each one, runs a computation that depends only on the arrays in that episode. The result is a float, which I store in a result array. Don't worry about what model.get_prob_this_action() does; treat it as a black box that takes two floats and returns a float. Here is the function before any optimization with JAX:

import numpy as np

def IS_estimate(model, theta, episodes):
    """ Calculate the unweighted importance sampling estimate
    for each episode in episodes.
    Return as an array, one element per episode
    """
    # episodes is an array of custom Python class instances
    
    gamma = 1.0
    result = np.zeros(len(episodes))
    for ii, ep in enumerate(episodes):
        obs = ep.observations # 1D array of floats
        actions = ep.actions # 1D array of floats
        rewards = ep.rewards # 1D array of floats
        action_probs = ep.action_probs # 1D array of floats

        pi_news = np.zeros(len(obs))
        for jj in range(len(obs)):
            pi_news[jj] = model.get_prob_this_action(obs[jj],actions[jj])

        pi_ratio_prod = np.prod(pi_news / action_probs)

        weighted_return = weighted_sum_gamma(rewards, gamma)
        result[ii] = pi_ratio_prod * weighted_return

    return np.array(result)

Unfortunately, I can't simply rewrite the function to work on a single episode and then vectorize it with jax.vmap, because the argument I want to vectorize over is a custom Episode object, which JAX won't support.

I can replace the inner loop that computes pi_news with vmap, like this:

import numpy as np
from jax import vmap

def IS_estimate(model, theta, episodes):
    """ Calculate the unweighted importance sampling estimate
    for each episode in episodes.
    Return as an array, one element per episode
    """
    # episodes is an array of custom Python class instances
    
    gamma = 1.0
    result = np.zeros(len(episodes))
    for ii, ep in enumerate(episodes):
        obs = ep.observations # 1D array of floats
        actions = ep.actions # 1D array of floats
        rewards = ep.rewards # 1D array of floats
        action_probs = ep.action_probs # 1D array of floats

        vmapped_get_prob_this_action = vmap(model.get_prob_this_action,in_axes=(0,0))
        pi_news = vmapped_get_prob_this_action(obs,actions)

        pi_ratio_prod = np.prod(pi_news / action_probs)

        weighted_return = weighted_sum_gamma(rewards, gamma)
        result[ii] = pi_ratio_prod * weighted_return

    return np.array(result)

This helps somewhat, but ideally I'd like to vmap the outer loop as well. Does anyone know how I would do this?


Solution

  • The computation you're describing is an "array-of-structs" style computation, which JAX's vmap does not support. What it does support is a "struct-of-arrays" style computation.

    As a quick demonstration of this, here's how you might do a simple per-episode computation using first the array-of-structs pattern (with Python for-loops) and then the struct-of-arrays pattern (with jax.vmap):

    from typing import NamedTuple
    import jax.numpy as jnp
    import numpy as np
    import jax
    
    class Episode(NamedTuple):
      observations: jnp.ndarray
      actions: jnp.ndarray
    
      def compute_result(self):
        # stand-in for computing some value from attributes
        return jnp.dot(self.observations, self.actions)
    
    # Computing result per episode on array of structs:
    rng = np.random.RandomState(42)
    episodes = [
        Episode(
            observations=jnp.array(rng.rand(4)),
            actions=jnp.array(rng.rand(4)))
        for i in range(5)
    ]
    result1 = jnp.array([ep.compute_result() for ep in episodes])
    print(result1)
    # [0.767802   0.83237386 0.49223748 0.5156544  1.1290307 ]
    
    # Computing results on struct of arrays via vmap:
    episodes_struct_of_arrays = Episode(
        observations = jnp.vstack([ep.observations for ep in episodes]),
        actions = jnp.vstack([ep.actions for ep in episodes])
    )
    result2 = jax.vmap(lambda self: self.compute_result())(episodes_struct_of_arrays)
    print(result2)
    # [0.767802   0.83237386 0.49223748 0.5156544  1.1290307 ]
    

    If you want to use JAX's vmap for this computation, you'll have to use a struct-of-arrays approach like the second one. Note that this also assumes that your Episode class is registered as a pytree (see extending pytrees), which is true by default for NamedTuple types.
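Applied to the importance sampling function from the question, the struct-of-arrays approach might look roughly like the sketch below. Several pieces are assumptions made for illustration and are not taken from the question: `get_prob_this_action` is a hypothetical logistic policy standing in for `model.get_prob_this_action`, `discounted_return` is a guess at what `weighted_sum_gamma` computes, and every episode is assumed to have the same number of steps so that the fields can be stacked.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
import numpy as np


class Episode(NamedTuple):
    observations: jnp.ndarray
    actions: jnp.ndarray
    rewards: jnp.ndarray
    action_probs: jnp.ndarray


def get_prob_this_action(theta, obs, action):
    # Hypothetical stand-in for model.get_prob_this_action:
    # a logistic policy over two actions (0.0 or 1.0).
    p1 = jax.nn.sigmoid(theta * obs)
    return jnp.where(action == 1.0, p1, 1.0 - p1)


def discounted_return(rewards, gamma):
    # Assumed meaning of weighted_sum_gamma: sum_t gamma**t * rewards[t].
    return jnp.sum(gamma ** jnp.arange(rewards.shape[0]) * rewards)


def is_estimate_single(theta, ep, gamma=1.0):
    # Inner vmap: one probability per time step of a single episode.
    pi_news = jax.vmap(get_prob_this_action, in_axes=(None, 0, 0))(
        theta, ep.observations, ep.actions
    )
    pi_ratio_prod = jnp.prod(pi_news / ep.action_probs)
    return pi_ratio_prod * discounted_return(ep.rewards, gamma)


# Outer vmap: theta is shared, every field of Episode is mapped over axis 0.
is_estimate = jax.vmap(is_estimate_single, in_axes=(None, 0))

# Made-up data: 5 episodes of 4 steps each (array of structs).
rng = np.random.RandomState(0)
n_episodes, n_steps = 5, 4
episodes = [
    Episode(
        observations=jnp.array(rng.rand(n_steps)),
        actions=jnp.array(rng.randint(0, 2, n_steps).astype(np.float32)),
        rewards=jnp.array(rng.rand(n_steps)),
        action_probs=jnp.full(n_steps, 0.5),
    )
    for _ in range(n_episodes)
]

# Convert to a struct of arrays: each field gets shape (n_episodes, n_steps).
stacked = Episode(*(jnp.stack(field) for field in zip(*episodes)))

theta = 0.3
result = is_estimate(theta, stacked)  # one estimate per episode

# Reference: the same computation with a Python loop over episodes.
loop_result = jnp.array([is_estimate_single(theta, ep) for ep in episodes])
```

Stacking only works when all episodes have the same length. If yours do not, one option is to pad each field to a common length and pass a mask so the padded steps contribute a ratio of 1 and a reward of 0, but that is beyond what this sketch covers.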