
I am currently implementing an algorithm with GPflow using GPR. I want to save the parameters after training the GPR model and then load the model for testing. Does anyone know the command?

joel

2 Answers


GPflow now has a page with tips & tricks, where you will find the answer to your question. But I'm going to paste a minimal working example (MWE) here as well:

Let's say you want to store a GPR model. You can do it with gpflow.saver.Saver():

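from pathlib import Path

import gpflow
import numpy as np

# GPflow 1.x API (TensorFlow 1)
kernel = gpflow.kernels.RBF(1)
x = np.random.randn(100, 1)
y = np.random.randn(100, 1)
model = gpflow.models.GPR(x, y, kernel)

filename = "/tmp/gpr.gpflow"
path = Path(filename)
if path.exists():
    path.unlink()  # remove a stale file left over from a previous run
saver = gpflow.saver.Saver()
saver.save(filename, model)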

To load it back, you can either load it into a new graph and session:

import tensorflow as tf
# A fresh graph and session keep the copy separate from the original model
with tf.Graph().as_default() as graph, tf.Session().as_default():
    model_copy = saver.load(filename)

or, if you want to load the model into the same session in which you saved it, you need a few extra steps: disable autocompilation when loading, then clear and recompile the copy:

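# Don't compile the copy automatically while loading
ctx_for_loading = gpflow.saver.SaverContext(autocompile=False)
model_copy = saver.load(filename, context=ctx_for_loading)
model_copy.clear()    # reset the copy's TensorFlow state
model_copy.compile()  # then build it in the current graph and session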

UPDATE 1 June 2020:

GPflow 2.0 doesn't provide a custom saver. It relies on TensorFlow checkpointing (tf.train.Checkpoint) and tf.saved_model instead. You can find examples in the GPflow intro.

Artem Artemev

One option I use for GPflow models is to save and load only the trainable parameters. This assumes you have a function that builds and compiles the model, so you can rebuild it and then assign the stored values to it. The following saves the variables to an HDF5 file:

import h5py

# GPflow 1.x: uses Parameterized.read_trainables() and Parameterized.assign()

def _load_model(model, load_file):
    """
    Assign trainable values stored in an HDF5 file to an already built model.
    """
    values = {}
    def _gather(name, obj):
        if isinstance(obj, h5py.Dataset):
            values[name] = obj[...]

    with h5py.File(load_file, "r") as f:
        f.visititems(_gather)

    model.assign(values)

def _save_model(model, save_file):
    values = model.read_trainables()
    with h5py.File(save_file, "w") as f:  # "w" creates or overwrites the file
        for name, value in values.items():
            f[name] = value
Josh Albert