
I am trying to perform multi-task learning on the UTKFace dataset. Training works fine in eager mode. However, when I switch to graph execution with tf.function, it doesn't work: the gradients returned are None. I am using ImageDataGenerator.flow_from_dataframe to create my training and validation loaders. Any help would be appreciated. A public Kaggle notebook/kernel that reproduces the results can be found HERE.
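The loaders may also be worth a look. As far as I know, the DataFrameIterator returned by flow_from_dataframe cycles indefinitely when iterated with `for`, and it yields NumPy arrays rather than tensors. Below is a minimal sketch of one alternative, assuming a loader that supports `len()` and integer indexing like a Keras `Sequence`: wrap one pass over it in `tf.data.Dataset.from_generator` with an explicit `output_signature`, so each epoch ends after `len(loader)` batches and the training loop receives tensors. `FakeLoader`, `one_epoch`, the 8-feature inputs and the label shapes are hypothetical stand-ins, not the notebook's real data.

```python
import numpy as np
import tensorflow as tf


class FakeLoader:
    """Hypothetical stand-in for the DataFrameIterator from flow_from_dataframe:
    indexable, with len() equal to the number of batches per epoch."""

    def __init__(self, num_batches=3, batch_size=4):
        self.num_batches = num_batches
        self.batch_size = batch_size

    def __len__(self):
        return self.num_batches

    def __getitem__(self, idx):
        rng = np.random.default_rng(idx)
        x = rng.random((self.batch_size, 8), dtype=np.float32)
        age = rng.random((self.batch_size, 1), dtype=np.float32)
        gender = rng.integers(0, 2, self.batch_size).astype(np.int32)
        race = rng.integers(0, 5, self.batch_size).astype(np.int32)
        return x, [age, gender, race]


def one_epoch(loader):
    # Index by position so the generator stops after len(loader) batches.
    for i in range(len(loader)):
        x, (age, gender, race) = loader[i]
        yield x, (age, gender, race)


loader = FakeLoader()
dataset = tf.data.Dataset.from_generator(
    lambda: one_epoch(loader),  # a fresh generator on every iteration
    output_signature=(
        tf.TensorSpec(shape=(None, 8), dtype=tf.float32),
        (
            tf.TensorSpec(shape=(None, 1), dtype=tf.float32),  # age
            tf.TensorSpec(shape=(None,), dtype=tf.int32),      # gender
            tf.TensorSpec(shape=(None,), dtype=tf.int32),      # race
        ),
    ),
)

for x_batch, y_batch in dataset:
    print(x_batch.shape, [t.shape for t in y_batch])
```

Because iterating `dataset` calls the lambda again, the same object can be reused for every epoch.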

def loss_fn(y_batch_train, logits):
    loss_1 = race_loss(y_batch_train[2], np.asarray(logits[2]))
    loss_2 = gender_loss(y_batch_train[1], np.asarray(logits[1]))
    loss_3 = age_loss(y_batch_train[0], np.asarray(logits[0]))
    
    loss = loss_1 + loss_2 + loss_3
    losses = [loss_1, loss_2, loss_3]
    
    return loss, losses
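One thing worth checking in `loss_fn`: `np.asarray(logits[i])` converts each head's output to a NumPy array before the loss sees it, and a `GradientTape` cannot differentiate through NumPy. The minimal sketch below uses a toy scalar variable (not the UTKFace model) to show the effect in eager mode: the same loss yields a gradient of `None` when the value passes through `np.asarray` and a real gradient when it stays a tensor.

```python
import numpy as np
import tensorflow as tf

w = tf.Variable(2.0)
x = tf.constant(3.0)


def grad_through_numpy():
    with tf.GradientTape() as tape:
        pred = w * x
        # np.asarray copies the tensor's value into a plain NumPy array,
        # so the tape has no record linking `loss` back to `w`.
        loss = tf.square(np.asarray(pred) - 1.0)
    return tape.gradient(loss, w)


def grad_with_tensors():
    with tf.GradientTape() as tape:
        pred = w * x
        loss = tf.square(pred - 1.0)  # stays a TensorFlow tensor throughout
    return tape.gradient(loss, w)


print(grad_through_numpy())          # None
print(grad_with_tensors().numpy())   # 30.0, i.e. 2 * (w*x - 1) * x
```

Keeping the whole loss computation in TensorFlow ops is what lets the tape reach the weights, whether or not the step is wrapped in tf.function.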


epochs = 2

for epoch in range(epochs):
    print('EPOCH', epoch)
    
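    # Iterate over the batches of the dataset. Index by batch number, because
    # iterating the flow_from_dataframe iterator directly never stops on its own.
    for step in range(len(train_dataloader)):
        x_batch_train, y_batch_train = train_dataloader[step]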
        
        # train_step opens a GradientTape, runs the forward pass, computes the
        # loss for this minibatch and applies the gradients. It returns the
        # total loss and the per-task losses [race, gender, age].
        loss_value, losses = train_step(x_batch_train, y_batch_train)
        
        # race_metrics.update_state(y_batch_train, logits)
        # gender_metrics.update_state(y_batch_train, logits)
        # age_metrics.update_state(y_batch_train, logits)
        
    print('Training Losses: Race: ' + str(losses[0].numpy()) + ', Gender: ' + str(losses[1].numpy()) + ', Age: ' + str(losses[2].numpy()))

    for step in range(len(valid_dataloader)):
        x_batch_val, y_batch_val = valid_dataloader[step]
        val_loss_value, val_losses = valid_step(x_batch_val, y_batch_val)

    print('Validation Losses: Race: ' + str(val_losses[0].numpy()) + ', Gender: ' + str(val_losses[1].numpy()) + ', Age: ' + str(val_losses[2].numpy()))
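`train_step` and `valid_step` are not shown above. For reference, here is a minimal sketch of what a tf.function-compatible pair could look like, following the custom-training-loop pattern from the Keras guides. The tiny three-headed dense model, the loss objects, the label shapes and the optimizer are all hypothetical stand-ins for the notebook's UTKFace model. The key point is that `loss_fn` passes the model outputs straight into the Keras losses, so `tape.gradient` returns a gradient for every trainable variable. Both functions return `(total_loss, per_task_losses)`, the same pair the validation loop unpacks from `valid_step`.

```python
import tensorflow as tf

# Hypothetical three-headed model standing in for the UTKFace network:
# a shared dense trunk with age (regression), gender and race (classification) heads.
inputs = tf.keras.Input(shape=(8,))
hidden = tf.keras.layers.Dense(16, activation="relu")(inputs)
age_out = tf.keras.layers.Dense(1, name="age")(hidden)
gender_out = tf.keras.layers.Dense(2, name="gender")(hidden)
race_out = tf.keras.layers.Dense(5, name="race")(hidden)
model = tf.keras.Model(inputs, [age_out, gender_out, race_out])

age_loss = tf.keras.losses.MeanSquaredError()
gender_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
race_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)


def loss_fn(y, logits):
    # Same label order as the question: y[0] = age, y[1] = gender, y[2] = race.
    # The model outputs go straight into the Keras losses (no NumPy conversion).
    loss_1 = race_loss(y[2], logits[2])
    loss_2 = gender_loss(y[1], logits[1])
    loss_3 = age_loss(y[0], logits[0])
    return loss_1 + loss_2 + loss_3, [loss_1, loss_2, loss_3]


@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss, losses = loss_fn(y, logits)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss, losses


@tf.function
def valid_step(x, y):
    # No tape and no weight update: only the forward pass and the losses.
    logits = model(x, training=False)
    return loss_fn(y, logits)


# Toy batch of 4 examples with random features and labels.
x = tf.random.normal((4, 8))
y = (
    tf.random.uniform((4, 1)),
    tf.constant([0, 1, 0, 1]),
    tf.constant([0, 1, 2, 3]),
)
loss, losses = train_step(x, y)
val_loss, val_losses = valid_step(x, y)
print("Train:", float(loss), "Valid:", float(val_loss))
```

Returning the per-task losses as a list keeps the printing code in the loop unchanged; tf.function converts them to tensors on the way out.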

Stevi G
