I'm using Ray RLlib to train a PPO agent with TWO modifications to the PPOTFPolicy.
- I added a mixin class (say "Recal") to the "mixins" parameter of "build_tf_policy()". This way, PPOTFPolicy subclasses my "Recal" class and gains access to the member functions I defined in "Recal". My "Recal" class is a simple subclass of tf.keras.Model.
- I defined a "my_postprocess_fn" function to replace the default "compute_gae_for_sample_batch" function, which is what is normally passed to the "postprocess_fn" parameter of "build_tf_policy()".
The "PPOTrainer = build_trainer(...)" call remains unchanged. I use framework="tf" with eager execution disabled.
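For context, this is my understanding of how the mixin ends up on the policy class, as a plain-Python sketch (not RLlib's actual implementation; "build_policy" and "recalibrate" are illustrative names):

```python
class Recal:
    """Stand-in for my mixin; in the real code this wraps a Keras model."""
    def recalibrate(self, x):
        return x * 2  # stand-in for self.recal_model.predict(x)

def build_policy(mixins):
    # build_tf_policy similarly composes a new class whose bases include
    # the user-supplied mixins, so their methods are inherited.
    return type("PPOTFPolicy", tuple(mixins) + (object,), {})

PPOTFPolicy = build_policy([Recal])
policy = PPOTFPolicy()
print(policy.recalibrate(21))  # 42 -- the mixin's method is available
```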
Pseudo code is below. Here is a running version on Colab.
tf.compat.v1.disable_eager_execution()

class Recal:
    def __init__(self):
        # Build the auxiliary Keras model used to recalibrate sample batches.
        self.recal_model = build_and_compile_keras_model()

def my_postprocess_fn(policy, sample_batch, other_agent_batches=None, episode=None):
    # Run the recalibration model inside the policy's graph, then apply
    # the default GAE postprocessing.
    with policy.model.graph.as_default():
        sample_batch = policy.recal_model.predict(sample_batch)
    return compute_gae_for_sample_batch(policy, sample_batch, other_agent_batches, episode)

PPOTFPolicy = build_tf_policy(..., postprocess_fn=my_postprocess_fn, mixins=[..., Recal])
PPOTrainer = build_trainer(...)

ppo_trainer = PPOTrainer(config=DEFAULT_CONFIG, env="CartPole-v0")
for i in range(1):
    result = ppo_trainer.train()
This way, "Recal" is a base class of PPOTFPolicy, and when a PPOTFPolicy instance is created, "Recal" is instantiated within the same TensorFlow graph. But when my_postprocess_fn() is called, it raises the following error:
tensorflow.python.framework.errors_impl.FailedPreconditionError: Could not find variable default_policy_wk1/my_model/dense/kernel. This could mean that the variable has been deleted. In TF1, it can also mean the variable is uninitialized. Debug info: container=localhost, status error message=Container localhost does not exist. (Could not find resource: localhost/default_policy_wk1/my_model/dense/kernel)
[[{{node default_policy_wk1/my_model_1/dense/MatMul/ReadVariableOp}}]]
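My suspicion is that the recal_model's variables were never initialized in the session that actually executes my_postprocess_fn. A minimal TF1-style sketch (plain TF, no RLlib; my assumption about the cause, not the confirmed one) that reproduces this class of error:

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()

# A Keras layer's variables are created in the default graph, but they
# only exist inside a session once that session runs their initializer.
dense = tf.keras.layers.Dense(1)
x = tf.compat.v1.placeholder(tf.float32, [None, 4])
y = dense(x)  # creates the kernel/bias variables

sess = tf.compat.v1.Session()
try:
    sess.run(y, {x: [[0.0] * 4]})  # variables uninitialized in this session
except tf.errors.FailedPreconditionError:
    print("FailedPreconditionError")  # same error class as above

# Initializing the variables in the executing session fixes it.
sess.run(tf.compat.v1.global_variables_initializer())
out = sess.run(y, {x: [[0.0] * 4]})
print(out.shape)  # (1, 1)
```

So my question is whether the Keras model built in the mixin is being read through a different session than the one it was initialized in, and how to make them agree.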