
I'm using Ray RLlib to train a PPO agent, with two modifications to the PPOTFPolicy.

  • I added a mixin class (say "Recal") to the "mixins" parameter of "build_tf_policy()". This way, the PPOTFPolicy subclasses my "Recal" class and has access to the member functions that I defined in "Recal". My "Recal" class is a simple subclass of tf.keras.Model.
  • I defined a "my_postprocess_fn" function to replace the "compute_gae_for_sample_batch" function that is normally passed to the "postprocess_fn" parameter of "build_tf_policy()".

The "PPOTrainer = build_trainer(...)" call is unchanged. I use framework="tf" and disable eager mode.

Pseudocode is below. Here is a running version at colab.

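import tensorflow as tf
from ray.rllib.agents.ppo import DEFAULT_CONFIG
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch
from ray.rllib.policy.tf_policy_template import build_tf_policy

# Run TensorFlow in graph (non-eager) mode.
tf.compat.v1.disable_eager_execution()

class Recal:
    # Mixin: gives the policy a "recal_model" attribute.
    def __init__(self):
        self.recal_model = build_and_compile_keras_model()

def my_postprocess_fn(policy, sample_batch):
    # Recalibrate the batch with the Keras model, then compute GAE as usual.
    with policy.model.graph.as_default():
        sample_batch = policy.recal_model.predict(sample_batch)
    return compute_gae_for_sample_batch(policy, sample_batch)

PPOTFPolicy = build_tf_policy(..., postprocess_fn=my_postprocess_fn, mixins=[..., Recal])
PPOTrainer = build_trainer(...)
ppo_trainer = PPOTrainer(config=DEFAULT_CONFIG, env="CartPole-v0")

for i in range(1):
    result = ppo_trainer.train()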
With this setup, "Recal" is a base class of PPOTFPolicy, and when a PPOTFPolicy instance is created, "Recal" is instantiated within the same TensorFlow graph. But when my_postprocess_fn() is called, it raises the error below.

tensorflow.python.framework.errors_impl.FailedPreconditionError: Could not find variable default_policy_wk1/my_model/dense/kernel. This could mean that the variable has been deleted. In TF1, it can also mean the variable is uninitialized. Debug info: container=localhost, status error message=Container localhost does not exist. (Could not find resource: localhost/default_policy_wk1/my_model/dense/kernel)
     [[{{node default_policy_wk1/my_model_1/dense/MatMul/ReadVariableOp}}]]
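For reference, here is a minimal sketch of one common way a FailedPreconditionError like this can arise in TF1-style graph mode: a variable is initialized in one session and then read through a different session on the same graph. Whether this is what happens inside RLlib here is an assumption, not something confirmed; the names "kernel", "sess_a" and "sess_b" are made up for illustration.

```python
import tensorflow as tf

# Graph (non-eager) mode, as in the setup above.
tf.compat.v1.disable_eager_execution()

graph = tf.Graph()
with graph.as_default():
    # A stand-in for a layer weight such as dense/kernel (hypothetical name).
    kernel = tf.compat.v1.get_variable(
        "kernel", shape=[2, 2], initializer=tf.compat.v1.zeros_initializer()
    )
    read_op = kernel.read_value()
    init_op = tf.compat.v1.global_variables_initializer()

# Two sessions on the SAME graph; each keeps its own variable state.
sess_a = tf.compat.v1.Session(graph=graph)
sess_b = tf.compat.v1.Session(graph=graph)

# Initialize only in session A; reading there works.
sess_a.run(init_op)
value_in_a = sess_a.run(read_op)

# Session B never ran the initializer, so the read fails.
try:
    sess_b.run(read_op)
    failed_in_b = False
except tf.errors.FailedPreconditionError:
    failed_in_b = True

sess_a.close()
sess_b.close()
```

So sharing the graph alone is not enough: the variables also have to be initialized in whichever session ends up running the op. If something similar is going on here, it would suggest checking which session predict() uses compared with the session the policy's variables were initialized in.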
WPXP

1 Answer


I have been exploring Ray for a while now, so I think I can give you an answer to this question.

Ray uses its own Model class, and this class does not have the tf.keras.Model.predict method for getting batch predictions. However, it does provide other options.

I have yet to find out whether the outputs of the two classes are equivalent. I came across your question while searching for an answer to exactly that. If you see this, I would be happy to continue the conversation. :)

spramuditha