I am trying to train the Transformer encoder from the TensorFlow tutorial (https://www.tensorflow.org/tutorials/text/transformer) on a TPU:
def test():
    train_step_signature = [
        tf.TensorSpec(shape=(None, None), dtype=tf.int64),
        tf.TensorSpec(shape=(None, None), dtype=tf.int64),
    ]

    @tf.function(input_signature=train_step_signature)
    def train_step(inp, tar):
        with tf.GradientTape() as tape:
            predictions = transf(inp,
                                 True,
                                 None,
                                 None,
                                 None)
            loss = loss_function(tar, predictions)  # <- error is here
            # I use SparseCategoricalCrossentropy()

    vocabsize = 1000
    transf = Transformer(num_layers, d_model, num_heads, dff,
                         vocabsize, vocabsize,
                         pe_input=vocabsize,
                         pe_target=vocabsize,
                         rate=dropout_rate)
    for iter in range(1, 75000):
        print(iter)
        inp = np.random.randint(vocabsize, size=(5, 11))
        tar = np.random.randint(vocabsize, size=(5, 11))
        train_step(inp, tar)
This works on CPU, but on TPU I get an error after about 100 iterations at the loss_function call (marked above):
InvalidArgumentError:
Function invoked by the following node is not compilable: {{node __inference_train_step_4179}} = __inference_train_step_4179[_XlaMustCompile=true, config_proto="\n\007\n\003GPU\020\000\n\007\n\003CPU\020\0012\002J\0008\001", executor_type=""](dummy_input, dummy_input, dummy_input, dummy_input...
Uncompilable nodes: sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/assert_equal_1/Assert/Const:
unsupported op: Const op with type DT_STRING is not supported by XLA.
Stacktrace: Node: __inference_train_step_4179, function: Node: sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/assert_equal_1/Assert/Const, function: __inference_train_step_4179 ...
As far as I understand, the error is caused by an assertion inside the loss function, which XLA does not support. What can I do here?
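One possible direction, sketched here rather than verified on a TPU: compute the same per-token cross-entropy from one-hot labels and a log-softmax, so that the SparseSoftmaxCrossEntropyWithLogits op named in the error (and its shape assertion) is never built. The function name `dense_sparse_crossentropy` is hypothetical, the padding mask used in the tutorial's loss_function is omitted, and the shapes below simply mirror the (5, 11) batches and the vocabulary of 1000 from the code above. As far as I can tell, the assertion is only added when the static shapes are not fully defined, so replacing the (None, None) shapes in the input signature with fixed ones may be another way around it.

```python
import numpy as np
import tensorflow as tf


def dense_sparse_crossentropy(labels, logits):
    """Per-token cross-entropy from integer labels, without the sparse op.

    labels: integer tensor of shape (batch, seq_len)
    logits: float tensor of shape (batch, seq_len, vocab_size)
    Hypothetical helper; it does not apply the tutorial's padding mask.
    """
    vocab_size = tf.shape(logits)[-1]
    # Expand integer labels to one-hot vectors over the vocabulary.
    one_hot = tf.one_hot(labels, depth=vocab_size, dtype=logits.dtype)
    # Numerically stable log-probabilities over the vocabulary axis.
    log_probs = tf.nn.log_softmax(logits, axis=-1)
    # Negative log-probability of the correct class for each token.
    return -tf.reduce_sum(one_hot * log_probs, axis=-1)


# Sanity check against the Keras loss on random data.
rng = np.random.default_rng(0)
labels = tf.constant(rng.integers(0, 1000, size=(5, 11)), dtype=tf.int64)
logits = tf.constant(rng.normal(size=(5, 11, 1000)), dtype=tf.float32)

reference = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction="none")(labels, logits)
manual = dense_sparse_crossentropy(labels, logits)

# Largest elementwise difference between the two losses.
max_abs_diff = float(tf.reduce_max(tf.abs(reference - manual)))
print(max_abs_diff)
```

The check only shows that the two losses agree numerically on this random batch. Whether the rewritten loss actually compiles under XLA on the TPU still has to be confirmed there, and the one-hot tensor costs extra memory proportional to the vocabulary size.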