I have a `train_step` function like this:

```python
@tf.function
def train_step(timestep_values, noised_image, noise):
    # calculate loss and update parameters
    with tf.GradientTape() as tape:
        prediction = model(noised_image, timestep_values)
        loss_value = loss_of(noise, prediction)
    gradients = tape.gradient(loss_value, model.trainable_variables)
    opt.apply_gradients(zip(gradients, model.trainable_variables))
    tf.print("end-train-step")
    return loss_value
```
And a main loop like this:

```python
EPOCHS = 1
for e in range(EPOCHS):
    for batch in X_train:
        rng, tsrng = np.random.randint(0, 100000, size=(2,))
        timestep_values = generate_timestamp(tsrng, batch.shape[0])
        noised_image, noise = forward_noise(rng, batch, timestep_values)
        loss_value = train_step(timestep_values, noised_image, noise)
        print("end-of-batch")
    print(f"Epoch {e+1}/{EPOCHS}")
```
My problem is that `tf.print("end-train-step")` is printed very quickly, but `print("end-of-batch")` only appears after 2-3 minutes of waiting (on Colab/Kaggle). Why?

I don't understand why the `train_step` call itself returns very fast, but everything becomes slow once `loss_value` is returned to the Python side. Why is this transition so slow?
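To quantify this transition, this is roughly how I would time a single batch; the `time` calls and the `float(loss_value)` conversion are instrumentation added just for this question, assuming `train_step` returns `loss_value` as shown above:

```python
import time

t0 = time.perf_counter()
loss_value = train_step(timestep_values, noised_image, noise)
t1 = time.perf_counter()  # reached almost immediately after the call
loss_host = float(loss_value)  # pulling the scalar back to Python waits here
t2 = time.perf_counter()
print(f"train_step call: {t1 - t0:.3f}s, fetching loss: {t2 - t1:.3f}s")
```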
I tried removing every `tf.print` and `print` from my functions, but that doesn't help. I also ran this code locally on my own CPU (Ryzen 5 3600) and it was faster than on Colab or Kaggle.
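If it matters, here is a minimal standalone sketch, independent of my model, that I believe shows the same pattern on a GPU (`heavy_step` is just a dummy function I wrote for this question):

```python
import time
import tensorflow as tf

@tf.function
def heavy_step(x):
    # stand-in for the model update: a chain of large matmuls
    for _ in range(50):
        x = tf.matmul(x, x) / 2000.0  # rescale to keep values finite
    tf.print("end-heavy-step")
    return x

x = tf.random.normal((2000, 2000))
heavy_step(x)  # first call: tracing + compilation

t0 = time.perf_counter()
y = heavy_step(x)  # returns as soon as the work is dispatched
t1 = time.perf_counter()
_ = y.numpy()      # blocks until the result is actually computed
t2 = time.perf_counter()
print(f"call returned after {t1 - t0:.3f}s, result ready after {t2 - t1:.3f}s")
```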