I am trying to make Ray Tune (with the wandb logger) stop the experiment under certain conditions:

  • stop the whole experiment if any trial raises an exception (so I can fix the code and resume)
  • stop if my score reaches -999
  • stop if the variable varcannotbezero becomes 0

Everything I have tried so far failed to produce the desired behavior:

  • stop={"score":-999,"varcannotbezero":0}
  • max_failures=0
  • defining a Stopper class did not work either:
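import time

from ray.tune import Stopper


class RayStopper(Stopper):
    def __init__(self):
        self._start = time.time()
        #self._deadline = 300
        # Initialize both values so stop_all() is safe before the first result arrives.
        self.score = None
        self.varcannotbezero = None

    def __call__(self, trial_id, result):
        # Remember the latest reported values; never stop an individual trial here.
        self.score = result["score"]
        self.varcannotbezero = result["varcannotbezero"]
        return False

    def stop_all(self):
        # Intended to stop the whole experiment on a bad score or a zero value.
        return self.score == -999 or self.varcannotbezero == 0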

Ray Tune just keeps running. This is how I start the experiment:

    wandb_project="ABC"
    wandb_api_key="KEY"
    ray.init(configure_logging=False)

    if current_best_params is None:
        algo = HyperOptSearch()
    else:
        algo = HyperOptSearch(points_to_evaluate=current_best_params,n_initial_points=n_initial_points)
    algo = ConcurrencyLimiter(algo, max_concurrent=1)

    scheduler = AsyncHyperBandScheduler()
    analysis = tune.run(
        tune_obj,
        name="Name",
        resources_per_trial={"cpu": 1},
        search_alg=algo,
        scheduler=scheduler,
        metric="score",
        mode="max",
        num_samples=10,
        stop={"score":-999,"varcannotbezero":0},
        max_failures=0,
        config=config,
        callbacks=[WandbLoggerCallback(project=wandb_project,entity="mycompany",api_key=wandb_api_key,log_config=True)],
        local_dir=local_dir,
        resume="AUTO",
        verbose=0
    )
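
A likely reason the dict form of `stop` does not behave as intended: Tune treats each entry as a threshold and stops a trial once the reported value is greater than or equal to it, rather than testing for equality, and the check ends only that one trial, not the experiment. The helper below is a hypothetical stand-in that mimics this comparison in plain Python; it does not call Ray, and the name `dict_stop_triggers` is an assumption made for illustration.

```python
def dict_stop_triggers(result, criteria):
    """Mimic (not call) Tune's per-trial dict check: any metric >= its threshold."""
    return any(result[key] >= threshold for key, threshold in criteria.items())


# The criteria from the tune.run call above.
criteria = {"score": -999, "varcannotbezero": 0}

# A perfectly healthy result already satisfies ">= -999" and ">= 0".
healthy = {"score": 0.3, "varcannotbezero": 5}
print(dict_stop_triggers(healthy, criteria))
```

Under this reading, the dict would end almost every trial after its first result while the experiment keeps launching new ones, which is why an equality condition such as "score is exactly -999" is usually expressed with a function or a Stopper instead.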

user670186

1 Answer

I found a way to stop the experiment with a custom Stopper class. However, the experiment then simply stops, and I did not find a way to resume it afterwards :(

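import time

from ray.tune import Stopper


class RayStopper(Stopper):
    def __init__(self):
        self._start = time.time()
        self.scoretostop = 0

    def __call__(self, trial_id, result):
        # Remember the latest flag reported by a trial; never stop a single trial here.
        self.scoretostop = result["scoretostop"]
        return False

    def stop_all(self):
        # Print a status message whenever this check lands on a multiple of 20 seconds.
        secs = int(time.time())
        runtime = secs - self._start
        if secs % 20 == 0:
            print("-----------------RayStopper--------------")
            print(f"runtime={runtime}")
            print(f"scoretostop={self.scoretostop}")
        # Stop the whole experiment once a trial has reported scoretostop == 1.
        return self.scoretostop == 1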
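
The same idea can be applied directly to the two conditions from the question. The following is a minimal sketch, not a tested Ray setup: the class name `ConditionStopper` and its `bad_score` parameter are hypothetical, and the fallback base class exists only so the snippet runs where Ray is not installed. Unlike the version above, it latches the stop flag, so a later healthy result from another trial cannot overwrite it.

```python
try:
    from ray.tune import Stopper  # real base class when Ray is installed
except ImportError:
    class Stopper:  # stand-in so the sketch runs without Ray (assumption)
        def __call__(self, trial_id, result):
            raise NotImplementedError

        def stop_all(self):
            raise NotImplementedError


class ConditionStopper(Stopper):
    """Hypothetical stopper: end everything on a bad score or a zero value."""

    def __init__(self, bad_score=-999):
        self._bad_score = bad_score
        self._stop_everything = False

    def __call__(self, trial_id, result):
        # .get() avoids a KeyError if a trial does not report one of the metrics.
        score = result.get("score")
        var = result.get("varcannotbezero")
        if score == self._bad_score or var == 0:
            self._stop_everything = True  # latch: stays True from now on
        return self._stop_everything

    def stop_all(self):
        return self._stop_everything


stopper = ConditionStopper()
print(stopper("trial_1", {"score": 0.5, "varcannotbezero": 3}), stopper.stop_all())
print(stopper("trial_2", {"score": -999, "varcannotbezero": 3}), stopper.stop_all())
```

An instance would be passed as `stop=ConditionStopper()` to `tune.run`, and the trainable has to report both keys for the checks to see them. For the exception case, `tune.run` also accepts a `fail_fast` argument that is meant to abort the run on the first trial error; whether it combines well with `resume="AUTO"` should be checked against the installed Ray version.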
user670186