I am new to ray.tune. My network is already written in a modular format, and now I am trying to incorporate ray.tune, but I do not know where the model should be initialized (as opposed to where the perturbed hyperparameters should be updated) so that the model and its weights are not re-initialized when a worker is truncated and replaced by a better-performing one.
Background
I am using the PBT scheduler of ray.tune, which creates num_samples models (workers), each initialized with a different set of sampled hyperparameters. When a worker is evaluated and found to be performing poorly, it is stopped and loads the checkpoint of one of the top-performing workers. Once that checkpoint is loaded (a deep copy of the network), the hyperparameters are perturbed, and the worker then trains until the next evaluation.
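To make sure I understand the exploit/explore cycle correctly, here is a toy stand-in for what I believe PBT is doing (this is purely illustrative plain Python, not the real ray.tune implementation; Worker, train_step, and pbt_step are names I made up):

```python
import random

class Worker:
    def __init__(self, lr):
        self.lr = lr          # sampled hyperparameter
        self.weights = [0.0]  # stands in for the model weights
        self.score = 0.0

    def train_step(self):
        # toy "training": a bigger lr moves the weight faster
        self.weights[0] += self.lr
        self.score = self.weights[0]

def pbt_step(workers):
    ranked = sorted(workers, key=lambda w: w.score, reverse=True)
    top, bottom = ranked[0], ranked[-1]
    # exploit: the worst worker loads the best worker's checkpoint
    bottom.weights = list(top.weights)
    # explore: perturb the hyperparameters it copied
    bottom.lr = top.lr * random.choice([0.8, 1.2])

workers = [Worker(lr) for lr in (0.001, 0.01, 0.1)]
for _ in range(5):
    for w in workers:
        w.train_step()
    pbt_step(workers)
```

The key point for my question is the exploit step: the copied weights must survive, so the truncated worker cannot simply re-run its model constructor.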
The MyTrainable class should have _setup, _train, _save, and _restore methods. _setup receives a config argument, and this is where the newly sampled hyperparameters are applied.
My question is: where should the original model be defined? I can easily apply the updated hyperparameters in _setup, but I have not seen anywhere in the documentation where a pre-defined model can be passed into the ray.tune.run function. If I keep the create_model() call in _setup(), though, it will discard the previously trained weights, which are part of the benefit of this method.
Code
Here are the three function calls I have:
self._hyperparameters(config) # redefines the self.opt options according to the new perturbations
self.model.update_optimizer(self.opt) # redefines the optimizers using the new learning rates and the beta values for Adam
self.model = create_model(self.opt) # Original function that defines the initial model and initializes the weights
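For reference, here is a sketch of the split I am considering: build the model once in _setup, and have _restore load only the weights and re-apply the perturbed hyperparameters. MockModel, create_model, and self.opt below are stand-ins for my own code, not ray.tune API, and the class does not actually subclass ray.tune's Trainable, so this is just the shape of the idea:

```python
class MockModel:
    def __init__(self, opt):
        self.weights = {"w": 0.0}  # freshly initialized weights
        self.opt = dict(opt)

    def update_optimizer(self, opt):
        # stands in for rebuilding the optimizers from new lr / Adam betas
        self.opt = dict(opt)

def create_model(opt):
    # stands in for my original model-construction function
    return MockModel(opt)

class MyTrainable:
    def _setup(self, config):
        self._hyperparameters(config)        # apply the sampled HPs to self.opt
        self.model = create_model(self.opt)  # build and initialize ONCE

    def _hyperparameters(self, config):
        self.opt = dict(config)

    def _save(self, checkpoint_dir):
        # checkpoint carries the weights; the model object itself persists
        return {"weights": dict(self.model.weights)}

    def _restore(self, checkpoint):
        # load the top performer's weights into the existing model ...
        self.model.weights = dict(checkpoint["weights"])
        # ... then re-apply the (already perturbed) HPs to the optimizer
        self.model.update_optimizer(self.opt)
```

With this split, a truncated worker that restores a checkpoint keeps the trained weights while still picking up its perturbed learning rate, which is the behavior I am after; I am unsure whether this is how ray.tune expects it to be done.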