
I'd like to use the RLlib-trained policy model in different code, where I need to track which action is generated for specific input states. A standard TensorFlow or PyTorch (preferred) network model would provide that flexibility, but I can't find clear documentation on how to produce a usable .dat or H5 file from a trained RLlib agent that I can then load into a PyTorch or TF/Keras model.

lakehopper
lakehopper is correct: there seems to be no documentation on how to save a model, which seems like a huge omission. When using Tune, one should be able to set export_formats=[ExportFormat.H5] (https://github.com/ray-project/ray/issues/17319), but that does not work either. The same problem is also raised here (https://discuss.ray.io/t/save-model-parameters-on-each-checkpoint/2892). Overriding 'save' in a custom model does not seem to do anything, and there is no documentation on how to invoke it automatically. model.save would be a perfect one-liner to achieve this. Can you please provide an example or po… – treadzero Jul 26 '21 at 08:40

1 Answer


The easiest way to get the weights from a checkpoint is to restore it with RLlib and then save the model with the TensorFlow/PyTorch commands. If you have a Keras (TF) model, you can simply call:

model.save('my_model.h5')  # creates an HDF5 file
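For the PyTorch route the question prefers, the same restore-then-save idea can be sketched as follows. This is a minimal sketch under several assumptions: the Ray 1.x API (`ray.rllib.agents.ppo.PPOTrainer`), a PPO agent trained with `"framework": "torch"`, and RLlib's default fully connected model with tanh hidden activations. `export_policy_weights` restores the checkpoint and writes the policy network's `state_dict` to a NumPy `.npz` file. `NumpyPolicy` is a hypothetical helper (not part of RLlib) that rebuilds the greedy forward pass from those arrays, so the code that tracks state-to-action mappings needs neither Ray nor PyTorch. The layer prefixes are assumptions; check them against the key names that the export returns, and leave out the value-function entries.

```python
import numpy as np


def export_policy_weights(checkpoint_path, npz_path, config, env):
    """Restore an RLlib PPO checkpoint and save the policy network's
    state_dict to a NumPy .npz file (one array per entry).

    Assumes the Ray 1.x API and a torch policy; `config` and `env` must match
    the ones used for training (including "framework": "torch").
    """
    import ray  # lazy imports: the loader below works without Ray/PyTorch
    from ray.rllib.agents.ppo import PPOTrainer

    ray.init(ignore_reinit_error=True)
    trainer = PPOTrainer(config=config, env=env)
    trainer.restore(checkpoint_path)
    model = trainer.get_policy().model  # a torch.nn.Module for torch policies
    arrays = {k: v.detach().cpu().numpy() for k, v in model.state_dict().items()}
    np.savez(npz_path, **arrays)
    return sorted(arrays)  # key names, used to pick the layer prefixes below


class NumpyPolicy:
    """Greedy (argmax) forward pass of a fully connected policy network,
    rebuilt from exported weights. Hypothetical helper, not part of RLlib."""

    def __init__(self, npz_path, layer_prefixes, activation=np.tanh):
        with np.load(npz_path) as weights:
            # torch.nn.Linear stores weight as (out_features, in_features)
            self.layers = [(weights[p + ".weight"], weights[p + ".bias"])
                           for p in layer_prefixes]
        self.activation = activation

    def logits(self, obs):
        x = np.asarray(obs, dtype=np.float32)
        for i, (w, b) in enumerate(self.layers):
            x = x @ w.T + b
            if i < len(self.layers) - 1:  # no activation on the output layer
                x = self.activation(x)
        return x

    def action(self, obs):
        """Deterministic action for a single (unbatched) observation."""
        return int(np.argmax(self.logits(obs)))


if __name__ == "__main__":
    # Assumed prefixes for RLlib's default torch FullyConnectedNetwork;
    # verify them against the list returned by export_policy_weights():
    #   ["_hidden_layers.0._model.0", "_hidden_layers.1._model.0",
    #    "_logits._model.0"]
    # Self-contained demo with random weights (4 inputs, 2 actions):
    rng = np.random.default_rng(0)
    np.savez("demo_policy.npz",
             **{"hidden.weight": rng.normal(size=(8, 4)), "hidden.bias": np.zeros(8),
                "out.weight": rng.normal(size=(2, 8)), "out.bias": np.zeros(2)})
    policy = NumpyPolicy("demo_policy.npz", ["hidden", "out"])
    for state in ([0.0, 0.0, 0.0, 0.0], [0.1, -0.2, 0.03, 0.5]):
        print(state, "->", policy.action(state))
```

This only reproduces the network itself. Any observation preprocessing or filtering that RLlib applied during training (for example an `observation_filter` such as `MeanStdFilter`) would have to be replicated before calling `action`. If you would rather stay in PyTorch, the same arrays could be loaded into a stack of `torch.nn.Linear` layers via `load_state_dict` after renaming the keys.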
Rocket