Having disabled eager execution, I am able to connect to my Cloud TPU and run my custom training loop. After calculating the loss, I would like to print its value.
Since the loss is a tensor living on the Cloud TPU, I haven't found any way to print it so far. `tf.print` just returns a `PrintOperation`, so I am stuck. I guess something like moving the loss tensor back to my CPU would work, but I haven't found even a hacky way to do that.
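To make the problem concrete, here is a minimal sketch of what happens once eager execution is disabled. The toy graph below just stands in for my real loss tensor, which in my case comes out of the training step on the TPU:

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # as in my setup

# Toy graph standing in for my real loss computation; the real loss tensor
# is produced by my training step on the TPU, this is only to show the issue.
x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
loss = tf.reduce_mean(tf.square(x))

# With eager execution disabled this prints nothing by itself; I only get the
# print operation back, e.g. <tf.Operation 'PrintV2' type=PrintV2>.
print_op = tf.print("loss:", loss)
print(print_op)

# On CPU I could run that op through a v1 session, but I don't know what the
# equivalent is when the tensor lives on the TPU inside my training loop.
with tf.compat.v1.Session() as sess:
    sess.run(print_op)
```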
I know it is possible to get that result, because when I wrap my model in Keras under a TPU distribution strategy and use `model.fit` instead of my custom training loop, the loss metrics do get printed.
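For comparison, here is a minimal sketch of the Keras path that does show the loss for me. The TPU name and the toy model/data are placeholders, not my real setup:

```python
import numpy as np
import tensorflow as tf

# Placeholder TPU name, just to show the pattern I mean.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='my-tpu-name')
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.experimental.TPUStrategy(resolver)  # tf.distribute.TPUStrategy in newer TF versions

with strategy.scope():
    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(8,))])
    model.compile(optimizer='sgd', loss='mse')

x = np.random.randn(1024, 8).astype('float32')
y = np.random.randn(1024, 1).astype('float32')

# The built-in progress bar prints the loss after every epoch here, so the
# value clearly makes it back from the TPU to the host somehow.
model.fit(x, y, batch_size=128, epochs=2)
```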
So there must be a way, and any help in finding it is greatly appreciated :)