I'm trying to train a simple EfficientNet-style model on some images. Training works fine on a CPU, but when I switch to a TPU I get the following error:
(0) Invalid argument: {{function_node __inference_train_function_38255}} Output shapes of then and else branches do not match: (s64[1,<=4]) vs. (s64[<=4])
[[{{node cond}}]]
[[TPUReplicate/_compile/_5430787790498024493/_4]]
[[tpu_compile_succeeded_assert/_6318656678166656164/_5/_289]]
(1) Invalid argument: {{function_node __inference_train_function_38255}} Output shapes of then and else branches do not match: (s64[1,<=4]) vs. (s64[<=4])
[[{{node cond}}]]
[[TPUReplicate/_compile/_5430787790498024493/_4]]
[[tpu_compile_succeeded_assert/_6318656678166656164/_5/_225]]
This error only occurs when I'm using a particular metric, Cohen's Kappa. If I remove this metric, the model trains fine.
I've tried to find the offending section in CohensKappa and narrowed it down to _update_confusion_matrix: if I override this and result, the model trains fine.
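For context, here is a minimal plain-Python sketch of what I understand the metric to compute: unweighted Cohen's Kappa from a running confusion matrix. It is not the library's code, and the function names (flatten_labels, update_confusion_matrix, cohen_kappa) are hypothetical. Labels are flattened first, so a batch shaped [1, n] and a batch shaped [n] are treated alike; that is the same rank difference the error message reports.

```python
def flatten_labels(labels):
    """Return labels as a flat list of ints, whether nested ([1, n]) or flat ([n])."""
    flat = []
    for item in labels:
        if isinstance(item, (list, tuple)):
            flat.extend(flatten_labels(item))
        else:
            flat.append(int(item))
    return flat


def update_confusion_matrix(matrix, y_true, y_pred):
    """Add one batch of (true, predicted) label pairs to the matrix in place."""
    for t, p in zip(flatten_labels(y_true), flatten_labels(y_pred)):
        matrix[t][p] += 1
    return matrix


def cohen_kappa(matrix):
    """Unweighted Cohen's Kappa from a square confusion matrix."""
    n = len(matrix)
    total = sum(sum(row) for row in matrix)
    if total == 0:
        return 0.0
    # Observed agreement: fraction of samples on the diagonal.
    observed = sum(matrix[i][i] for i in range(n)) / total
    # Expected agreement by chance, from the row and column marginals.
    row_totals = [sum(row) for row in matrix]
    col_totals = [sum(matrix[i][j] for i in range(n)) for j in range(n)]
    expected = sum(r * c for r, c in zip(row_totals, col_totals)) / (total * total)
    if expected == 1.0:
        return 0.0
    return (observed - expected) / (1.0 - expected)


# The true labels arrive nested ([1, 4]) and the predictions flat ([4]).
matrix = [[0, 0], [0, 0]]
update_confusion_matrix(matrix, [[0, 1, 1, 0]], [0, 1, 0, 0])
kappa = cohen_kappa(matrix)
```

The point of the sketch is only that the update step has to see labels of one consistent rank; I don't know whether that is where the TPU compilation actually diverges.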
When I start training, I see this log message:
TPU has inputs with dynamic shapes: [<tf.Tensor 'Const:0' shape=() dtype=int32>, <tf.Tensor 'cond_8/Identity:0' shape=(None, 456, 456, 3) dtype=float32>, <tf.Tensor 'cond_8/Identity_1:0' shape=(None,) dtype=int64>]
This may be related. However, I still get that message when I leave the metric out and the model trains fine, so it might be a red herring.
Any suggestions for a fix, or on how to debug this, would be very helpful. Switching to eager execution mode isn't an option, as it all works fine on CPU.