
I have code to calculate the off-policy importance sampling estimate commonly used in reinforcement learning. It is not essential to know what that is, but readers who do may find this question a little easier to follow. Basically, I have a 1D array of instances of a custom Episode class. An Episode has four attributes, all of which are arrays of floats. I have a function that loops over all episodes and, for each one, does a computation based only on the arrays in that episode. The result of that computation is a float, which I then store in a result array. Don't worry about what model.get_prob_this_action() does; you can consider it a black box that takes two floats as input and returns a float. The code for this function, before optimizing with JAX, is:

def IS_estimate(model, theta, episodes):
    """ Calculate the unweighted importance sampling estimate
    for each episode in episodes.
    Return as an array, one element per episode
    """
    # episodes is an array of custom Python class instances
    
    gamma = 1.0
    result = np.zeros(len(episodes))
    for ii, ep in enumerate(episodes):
        obs = ep.observations # 1D array of floats
        actions = ep.actions # 1D array of floats
        rewards = ep.rewards # 1D array of floats
        action_probs = ep.action_probs # 1D array of floats

        pi_news = np.zeros(len(obs))
        for jj in range(len(obs)):
            pi_news[jj] = model.get_prob_this_action(obs[jj], actions[jj])

        pi_ratio_prod = np.prod(pi_news / action_probs)

        weighted_return = weighted_sum_gamma(rewards, gamma)
        result[ii] = pi_ratio_prod * weighted_return

    return np.array(result)

Unfortunately, I cannot just rewrite the function to work on a single episode and then use jax.vmap to vectorize that function over episodes. The reason is that the argument I want to vectorize over is a custom Episode object, which JAX won't support.

I can get rid of the inner loop to get pi_news using vmap, like:

def IS_estimate(model, theta, episodes):
    """ Calculate the unweighted importance sampling estimate
    for each episode in episodes.
    Return as an array, one element per episode
    """
    # episodes is an array of custom Python class instances
    
    gamma = 1.0
    result = np.zeros(len(episodes))
    for ii, ep in enumerate(episodes):
        obs = ep.observations # 1D array of floats
        actions = ep.actions # 1D array of floats
        rewards = ep.rewards # 1D array of floats
        action_probs = ep.action_probs # 1D array of floats

        vmapped_get_prob_this_action = vmap(model.get_prob_this_action, in_axes=(0, 0))
        pi_news = vmapped_get_prob_this_action(obs, actions)

        pi_ratio_prod = np.prod(pi_news / action_probs)

        weighted_return = weighted_sum_gamma(rewards, gamma)
        result[ii] = pi_ratio_prod * weighted_return

    return np.array(result)

and this does help some. But ideally, I'd like to vmap my outer loop as well. Does anyone know how I would do this?

marvin

1 Answer

The computation you're describing is an "array-of-structs" style computation; JAX's vmap does not support this. What it does support is a "struct-of-arrays" style computation.

As a quick demonstration of this, here's how you might do a simple per-episode computation using first the array-of-structs pattern (with Python for-loops) and then the struct-of-arrays pattern (with jax.vmap):

from typing import NamedTuple
import jax.numpy as jnp
import numpy as np
import jax

class Episode(NamedTuple):
  observations: jnp.ndarray
  actions: jnp.ndarray

  def compute_result(self):
    # stand-in for computing some value from attributes
    return jnp.dot(self.observations, self.actions)

# Computing result per episode on array of structs:
rng = np.random.RandomState(42)
episodes = [
    Episode(
        observations=jnp.array(rng.rand(4)),
        actions=jnp.array(rng.rand(4)))
    for i in range(5)
]
result1 = jnp.array([ep.compute_result() for ep in episodes])
print(result1)
# [0.767802   0.83237386 0.49223748 0.5156544  1.1290307 ]

# Computing results on struct of arrays via vmap:
episodes_struct_of_arrays = Episode(
    observations = jnp.vstack([ep.observations for ep in episodes]),
    actions = jnp.vstack([ep.actions for ep in episodes])
)
result2 = jax.vmap(lambda self: self.compute_result())(episodes_struct_of_arrays)
print(result2)
# [0.767802   0.83237386 0.49223748 0.5156544  1.1290307 ]

If you want to use JAX's vmap for this computation, you'll have to use a struct-of-arrays approach like the second one. Note that this also assumes that your Episode class is registered as a pytree (see the JAX docs on extending pytrees), which is true by default for NamedTuple types.
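
For a class that is not a NamedTuple, the registration can be done explicitly. Here is a minimal sketch, assuming a hypothetical PlainEpisode class standing in for your own (the class name and attributes are illustrative, not from the code above):

import jax
import jax.numpy as jnp
from jax import tree_util

@tree_util.register_pytree_node_class
class PlainEpisode:
  def __init__(self, observations, actions):
    self.observations = observations
    self.actions = actions

  def tree_flatten(self):
    # children are the array leaves vmap will map over; no static aux data
    return (self.observations, self.actions), None

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    return cls(*children)

# Once registered, a stacked instance can be passed straight to jax.vmap:
stacked = PlainEpisode(observations=jnp.ones((5, 4)), actions=jnp.ones((5, 4)))
result = jax.vmap(lambda ep: jnp.dot(ep.observations, ep.actions))(stacked)
print(result.shape)  # (5,)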

jakevdp
  • Hi Jake, thanks for the quick response. This is a neat approach, and it will come in handy in other places for me. However, `vstack()` fails in this case because `ep.observations` is not the same length for each episode. In reinforcement learning, episodes often have different lengths. The observations, actions, etc. of the *same* episode are guaranteed to be the same length, though, at least for the problems I am dealing with. I could maybe pad the shorter episodes with some value, but I'd prefer not to if possible. – marvin Oct 02 '22 at 00:24
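
For what it's worth, padding can be made exact for this particular estimate: if padded steps contribute a probability ratio of 1.0 to the product and a reward of 0.0 to the discounted sum, they cannot change the result. A rough sketch under those assumptions (pad_to is an illustrative helper, and weighted_sum_gamma is assumed to compute the discounted sum of rewards, which is inlined below):

import jax
import jax.numpy as jnp
import numpy as np

def pad_to(arr, length, value):
    # right-pad a 1D array with `value` up to `length` (illustrative helper)
    return np.pad(np.asarray(arr), (0, length - len(arr)), constant_values=value)

def IS_estimate_padded(model, theta, episodes):
    gamma = 1.0
    max_len = max(len(ep.rewards) for ep in episodes)

    # Stack episodes into 2D arrays, padding each one out to max_len.
    # Padded action_probs are 1.0 and padded rewards are 0.0, so padded
    # steps change neither the product nor the discounted sum.
    obs = jnp.array([pad_to(ep.observations, max_len, 0.0) for ep in episodes])
    actions = jnp.array([pad_to(ep.actions, max_len, 0.0) for ep in episodes])
    rewards = jnp.array([pad_to(ep.rewards, max_len, 0.0) for ep in episodes])
    action_probs = jnp.array([pad_to(ep.action_probs, max_len, 1.0) for ep in episodes])
    mask = jnp.array([pad_to(np.ones(len(ep.rewards)), max_len, 0.0) for ep in episodes])

    discounts = gamma ** jnp.arange(max_len)

    def one_episode(obs, actions, rewards, action_probs, mask):
        pi_news = jax.vmap(model.get_prob_this_action)(obs, actions)
        # Force the ratio to 1.0 at padded steps, since get_prob_this_action
        # was evaluated on padding values there.
        ratios = jnp.where(mask > 0, pi_news / action_probs, 1.0)
        return jnp.prod(ratios) * jnp.sum(rewards * discounts)

    return jax.vmap(one_episode)(obs, actions, rewards, action_probs, mask)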