I have code to calculate the off-policy importance sampling estimate commonly used in reinforcement learning. You don't need to know what that is, but for those who do, it may make this question easier to follow. Basically, I have a 1D array of instances of a custom Episode class. An Episode has four attributes, all of which are arrays of floats. I have a function that loops over all episodes and, for each one, does a computation based only on the arrays in that episode. The result of that computation is a float, which I then store in a result array. Don't worry about what model.get_prob_this_action() does; you can consider it a black box that takes two floats as input and returns a float. The code for this function before optimizing with JAX is:
def IS_estimate(model, theta, episodes):
    """ Calculate the unweighted importance sampling estimate
    for each episode in episodes.
    Return as an array, one element per episode
    """
    # episodes is an array of custom Python class instances
    gamma = 1.0
    result = np.zeros(len(episodes))
    for ii, ep in enumerate(episodes):
        obs = ep.observations  # 1D array of floats
        actions = ep.actions  # 1D array of floats
        rewards = ep.rewards  # 1D array of floats
        action_probs = ep.action_probs  # 1D array of floats
        pi_news = np.zeros(len(obs))
        for jj in range(len(obs)):
            pi_news[jj] = model.get_prob_this_action(obs[jj], actions[jj])
        pi_ratio_prod = np.prod(pi_news / action_probs)
        weighted_return = weighted_sum_gamma(rewards, gamma)
        result[ii] = pi_ratio_prod * weighted_return
    return np.array(result)
Unfortunately, I cannot just rewrite the function to work on a single episode and then use jax.vmap to vectorize over that function. The reason is that the argument I want to vectorize over is a custom Episode object, which JAX won't support.
I can get rid of the inner loop that computes pi_news by using vmap, like so:
def IS_estimate(model, theta, episodes):
    """ Calculate the unweighted importance sampling estimate
    for each episode in episodes.
    Return as an array, one element per episode
    """
    # episodes is an array of custom Python class instances
    gamma = 1.0
    result = np.zeros(len(episodes))
    for ii, ep in enumerate(episodes):
        obs = ep.observations  # 1D array of floats
        actions = ep.actions  # 1D array of floats
        rewards = ep.rewards  # 1D array of floats
        action_probs = ep.action_probs  # 1D array of floats
        vmapped_get_prob_this_action = vmap(model.get_prob_this_action, in_axes=(0, 0))
        pi_news = vmapped_get_prob_this_action(obs, actions)
        pi_ratio_prod = np.prod(pi_news / action_probs)
        weighted_return = weighted_sum_gamma(rewards, gamma)
        result[ii] = pi_ratio_prod * weighted_return
    return np.array(result)
and this does help some. But ideally, I'd like to vmap my outer loop as well. Does anyone know how I would do this?
The computation you're describing is an "array-of-structs" style computation; JAX's vmap does not support this. What it does support is a "struct-of-arrays" style computation.
As a quick demonstration of this, here's how you might do a simple per-episode computation using first the array-of-structs pattern (with Python for-loops) and then the struct-of-arrays pattern (with jax.vmap):
from typing import NamedTuple
import jax.numpy as jnp
import numpy as np
import jax

class Episode(NamedTuple):
    observations: jnp.ndarray
    actions: jnp.ndarray

    def compute_result(self):
        # stand-in for computing some value from attributes
        return jnp.dot(self.observations, self.actions)

# Computing result per episode on array of structs:
rng = np.random.RandomState(42)
episodes = [
    Episode(
        observations=jnp.array(rng.rand(4)),
        actions=jnp.array(rng.rand(4)))
    for i in range(5)
]
result1 = jnp.array([ep.compute_result() for ep in episodes])
print(result1)
# [0.767802 0.83237386 0.49223748 0.5156544 1.1290307 ]

# Computing results on struct of arrays via vmap:
episodes_struct_of_arrays = Episode(
    observations=jnp.vstack([ep.observations for ep in episodes]),
    actions=jnp.vstack([ep.actions for ep in episodes])
)
result2 = jax.vmap(lambda self: self.compute_result())(episodes_struct_of_arrays)
print(result2)
# [0.767802 0.83237386 0.49223748 0.5156544 1.1290307 ]
If you want to use JAX's vmap for this computation, you'll have to use a struct-of-arrays approach like the second one. Note that this also assumes that your Episode class is registered as a pytree (see extending pytrees), which is true by default for NamedTuple types.
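Applied to the importance sampling estimate from the question, the struct-of-arrays approach might look roughly like the sketch below. Several pieces are assumptions rather than the asker's actual code: get_prob_this_action is a hypothetical stand-in for model.get_prob_this_action (any JAX-traceable function of scalars would do), weighted_sum_gamma is assumed to be a discounted sum of rewards, and all episodes are assumed to have the same length so that their arrays can be stacked.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
import numpy as np


class Episode(NamedTuple):
    observations: jnp.ndarray
    actions: jnp.ndarray
    rewards: jnp.ndarray
    action_probs: jnp.ndarray


def get_prob_this_action(theta, ob, action):
    # Hypothetical stand-in for model.get_prob_this_action;
    # it maps scalars to a probability in (0, 1).
    return jax.nn.sigmoid(theta * ob * action)


def weighted_sum_gamma(rewards, gamma):
    # Assumed meaning of the question's helper: discounted sum of rewards.
    discounts = gamma ** jnp.arange(rewards.shape[0])
    return jnp.sum(discounts * rewards)


def is_estimate_single(theta, ep, gamma=1.0):
    # Importance sampling estimate for ONE episode (1D arrays).
    pi_news = jax.vmap(get_prob_this_action, in_axes=(None, 0, 0))(
        theta, ep.observations, ep.actions)
    pi_ratio_prod = jnp.prod(pi_news / ep.action_probs)
    return pi_ratio_prod * weighted_sum_gamma(ep.rewards, gamma)


def stack_episodes(episodes):
    # Array of structs -> struct of arrays: stack each field across
    # episodes. Requires every episode to have the same length.
    return jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *episodes)


rng = np.random.RandomState(0)
episodes = [
    Episode(
        observations=jnp.array(rng.rand(4)),
        actions=jnp.array(rng.rand(4)),
        rewards=jnp.array(rng.rand(4)),
        action_probs=jnp.array(rng.uniform(0.1, 1.0, 4)))
    for _ in range(5)
]
theta = 0.5

# Outer loop replaced by vmap over the leading axis of every field.
batched = stack_episodes(episodes)
result_vmap = jax.vmap(is_estimate_single, in_axes=(None, 0))(theta, batched)

# Reference: the same computation with a Python loop over episodes.
result_loop = jnp.array([is_estimate_single(theta, ep) for ep in episodes])
```

Here in_axes=(None, 0) leaves theta unbatched and maps over axis 0 of every array in the Episode. If your episodes have different lengths, the stacking step will not work as written, and you would need something like padding to a common length before this pattern applies.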