
I am training multiple models in the same Colab notebook to compare their results. To avoid repeating code I've written a function, and I've added WandbCallback() to the list of callbacks passed to model_name.fit():

import os
from tensorflow.keras import callbacks
from wandb.keras import WandbCallback

# save_dir, train_gen and val_gen are defined elsewhere in the notebook
def generic_FE_trainer(model_name, checkpoint_filename, min_lr=1e-7):
    earlystop = callbacks.EarlyStopping(monitor="val_loss",
                                        patience=11,
                                        verbose=1)
    lr_reduction = callbacks.ReduceLROnPlateau(monitor='val_accuracy',
                                               patience=5,
                                               verbose=1,
                                               factor=0.8,
                                               min_lr=min_lr)
    checkpoint_dir = os.path.join(save_dir, checkpoint_filename)
    os.makedirs(checkpoint_dir, exist_ok=True)  # no-op if the folder already exists
    checkpoint = callbacks.ModelCheckpoint(filepath=os.path.join(checkpoint_dir, checkpoint_filename + '.h5'),
                                           monitor="val_loss",
                                           verbose=1,
                                           save_best_only=True,
                                           save_weights_only=False)
    # batch_size is not passed: Keras expects the generator itself to yield batches
    return model_name.fit(train_gen,
                          epochs=40,
                          verbose=1,
                          callbacks=[earlystop, checkpoint, lr_reduction, WandbCallback()],
                          validation_data=val_gen)

I call it with something like the code below, but my project dashboard then puts the data for both models on the same graphs. As an example, I've attached the graph for epochs (not particularly useful in itself).

history1 = generic_FE_trainer(model1, 'model1')
history2 = generic_FE_trainer(model2, 'model2')

[Image: wandb "epochs" graph showing both models' data on the same plot]

This is the same for all my metrics, so how can I have wandb plot these graphs separately? I would like them to be in different runs, if that's possible.

mpnm

1 Answer


The epochs end up on the same graph because the first run never ends, so the second model keeps logging into it. Call w = wandb.init() again to start a new run, and make sure to pass reinit=True. To end a run, call w.finish().

So here, separate runs would be created with:

import wandb

w = wandb.init(reinit=True, name='model1')
history1 = generic_FE_trainer(model1, 'model1')
w.finish()
w = wandb.init(reinit=True, name='model2')
history2 = generic_FE_trainer(model2, 'model2')
w.finish()
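If you have more than two models, the same init/train/finish pattern can be wrapped in a loop. This is a minimal sketch, assuming the generic_FE_trainer(model, checkpoint_filename) signature from the question. The helper name train_each_in_own_run and its init parameter are hypothetical, added only so the run factory can be swapped out (for example in a test); by default it uses wandb.init.

```python
def train_each_in_own_run(models, trainer, project=None, init=None):
    """Train each model inside its own wandb run (hypothetical helper).

    models:  dict mapping a run name to a compiled Keras model
    trainer: callable(model, name) -> History, e.g. generic_FE_trainer
    init:    run factory; defaults to wandb.init (hook added for testing)
    """
    if init is None:
        import wandb  # imported lazily so the helper can be exercised without wandb
        init = wandb.init
    histories = {}
    for name, model in models.items():
        # reinit=True lets a later init() start a fresh run in the same process
        run = init(project=project, name=name, reinit=True)
        try:
            histories[name] = trainer(model, name)
        finally:
            # Always close the run, even if fit() raises, so the next model
            # does not log into a stale run.
            run.finish()
    return histories
```

Usage would look like histories = train_each_in_own_run({'model1': model1, 'model2': model2}, generic_FE_trainer). The try/finally makes the finish() step explicit; recent wandb versions also let you write with wandb.init(...) as run:, which finishes the run when the block exits.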