
I have a train_step function like this:

@tf.function
def train_step(timestep_values, noised_image, noise):
    # Compute the loss and update the model parameters.
    with tf.GradientTape() as tape:
        prediction = model(noised_image, timestep_values)
        loss_value = loss_of(noise, prediction)
    gradients = tape.gradient(loss_value, model.trainable_variables)
    opt.apply_gradients(zip(gradients, model.trainable_variables))
    tf.print("end-train-step")
    return loss_value

And a main loop like this:

EPOCHS = 1
for e in range(EPOCHS):
    for batch in X_train:
        rng, tsrng = np.random.randint(0, 100000, size=(2,))
        timestep_values = generate_timestamp(tsrng, batch.shape[0])
        noised_image, noise = forward_noise(rng, batch, timestep_values)
        train_step(timestep_values, noised_image, noise)
        print("end-of-batch")
    print(f"Epoch {e+1}/{EPOCHS}")

My problem is that tf.print("end-train-step") prints very quickly, but print("end-of-batch") does not appear until two or three minutes later (on Colab/Kaggle). Why?

I don't understand why the train_step function seems to execute very quickly, but everything becomes slow at the point where loss_value is returned. Why is this transition so slow?

I tried removing all the tf.print and print calls from my functions, but that didn't help. I also ran this code on my own machine on CPU (a Ryzen 5 3600), and it was faster than Colab or Kaggle.
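
A quick way to see where the time really goes would be something like this (just a sketch reusing the names above; calling .numpy() on the returned tensor blocks until TensorFlow has finished the pending work, so the measured time includes the actual computation):

import time

t0 = time.time()
loss_value = train_step(timestep_values, noised_image, noise)
# Converting to NumPy forces TensorFlow to finish executing the step,
# so the elapsed time below covers the whole computation, not just dispatch.
_ = loss_value.numpy()
print(f"one step took {time.time() - t0:.2f}s")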

1 Answer


OK guys, I just forgot to activate the GPU on Colab/Kaggle. I still don't understand why the transition between the two functions is slow, but anyway, it's working well now!
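
For anyone who hits the same thing, here is a quick sanity check (standard TensorFlow API) to confirm the runtime actually sees a GPU:

import tensorflow as tf

# An empty list means TensorFlow is running on CPU only,
# e.g. because the GPU runtime was not enabled on Colab/Kaggle.
print(tf.config.list_physical_devices('GPU'))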
