I have code to calculate the off-policy importance sampling estimate commonly used in reinforcement learning. It is not important to know what that is, but for someone who does, it might help them understand this question a little better. Basically, I have a 1D array of instances of a custom Episode class. An Episode has four attributes, all of which are arrays of floats. I have a function that loops over all episodes and, for each one, does a computation based only on the arrays in that episode. The result of that computation is a float, which I then store in a result array. Don't worry about what model.get_prob_this_action() does; you can consider it a black box that takes two floats as input and returns a float. The code for this function before optimizing with JAX is:
import numpy as np

def IS_estimate(model, theta, episodes):
    """ Calculate the unweighted importance sampling estimate
    for each episode in episodes.
    Return as an array, one element per episode
    """
    # episodes is an array of custom Python class instances
    gamma = 1.0
    result = np.zeros(len(episodes))
    for ii, ep in enumerate(episodes):
        obs = ep.observations           # 1D array of floats
        actions = ep.actions            # 1D array of floats
        rewards = ep.rewards            # 1D array of floats
        action_probs = ep.action_probs  # 1D array of floats
        pi_news = np.zeros(len(obs))
        for jj in range(len(obs)):
            pi_news[jj] = model.get_prob_this_action(obs[jj], actions[jj])
        pi_ratio_prod = np.prod(pi_news / action_probs)
        weighted_return = weighted_sum_gamma(rewards, gamma)
        result[ii] = pi_ratio_prod * weighted_return
    return np.array(result)
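For reference, here is a minimal sketch of the Episode container and the weighted_sum_gamma helper, simplified just enough to keep this question self-contained (weighted_sum_gamma here is just a discounted sum of rewards):

import numpy as np

class Episode:
    """Simple container for one episode's data."""
    def __init__(self, observations, actions, rewards, action_probs):
        self.observations = observations    # 1D array of floats
        self.actions = actions              # 1D array of floats
        self.rewards = rewards              # 1D array of floats
        self.action_probs = action_probs    # 1D array of floats

def weighted_sum_gamma(rewards, gamma):
    """Discounted return: sum over t of gamma**t * rewards[t]."""
    return np.sum(gamma ** np.arange(len(rewards)) * rewards)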
Unfortunately, I cannot just rewrite the function to work on a single episode and then use jax.vmap to vectorize over that function. The reason is that the argument I want to vectorize is a custom Episode object, which JAX won't support.
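Concretely, the per-episode function I would like to vmap looks roughly like the sketch below (reusing weighted_sum_gamma from above); the commented-out call at the end is the part JAX rejects, since episodes is an array of plain Python objects rather than arrays:

import jax.numpy as jnp
from jax import vmap

def IS_estimate_single(model, theta, ep):
    # Works fine for a single Episode instance...
    pi_news = vmap(model.get_prob_this_action, in_axes=(0, 0))(ep.observations, ep.actions)
    pi_ratio_prod = jnp.prod(pi_news / ep.action_probs)
    return pi_ratio_prod * weighted_sum_gamma(ep.rewards, 1.0)

# ...but vmap has no way to map over an array of Episode objects:
# vmap(IS_estimate_single, in_axes=(None, None, 0))(model, theta, episodes)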
I can get rid of the inner loop that fills pi_news by using vmap, like so:
import numpy as np
from jax import vmap

def IS_estimate(model, theta, episodes):
    """ Calculate the unweighted importance sampling estimate
    for each episode in episodes.
    Return as an array, one element per episode
    """
    # episodes is an array of custom Python class instances
    gamma = 1.0
    result = np.zeros(len(episodes))
    for ii, ep in enumerate(episodes):
        obs = ep.observations           # 1D array of floats
        actions = ep.actions            # 1D array of floats
        rewards = ep.rewards            # 1D array of floats
        action_probs = ep.action_probs  # 1D array of floats
        vmapped_get_prob_this_action = vmap(model.get_prob_this_action, in_axes=(0, 0))
        pi_news = vmapped_get_prob_this_action(obs, actions)
        pi_ratio_prod = np.prod(pi_news / action_probs)
        weighted_return = weighted_sum_gamma(rewards, gamma)
        result[ii] = pi_ratio_prod * weighted_return
    return np.array(result)
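(As a side note, vmapped_get_prob_this_action does not depend on anything inside the loop, so it could be built once before the for loop rather than on every iteration; that removes a little Python overhead but does not change the fact that the outer loop is still plain Python.)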
and this does help some. But ideally, I'd like to vmap my outer loop as well. Does anyone know how I would do this?
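To make it concrete, below is a rough sketch of the kind of thing I am hoping for, under the simplifying assumption that every episode has the same length, so the per-episode 1D arrays can be stacked into 2D arrays of shape (n_episodes, ep_len). I am not sure this is the right way to set it up (and get_prob_this_action would have to be something JAX can trace), but it shows the shape of the result I am after:

import jax.numpy as jnp
from jax import vmap

def IS_estimate_stacked(model, theta, obs, actions, rewards, action_probs, gamma=1.0):
    # theta is kept only to mirror the signatures above; unused here.
    # obs, actions, rewards, action_probs: 2D arrays of shape (n_episodes, ep_len),
    # built by stacking the per-episode arrays, e.g.
    #   obs = jnp.stack([ep.observations for ep in episodes])
    def single_episode(o, a, r, p):
        pi_news = vmap(model.get_prob_this_action, in_axes=(0, 0))(o, a)
        discounts = gamma ** jnp.arange(r.shape[0])  # weighted_sum_gamma, inlined
        return jnp.prod(pi_news / p) * jnp.sum(discounts * r)
    # vmap over the leading (episode) axis instead of a Python for loop
    return vmap(single_episode)(obs, actions, rewards, action_probs)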