I am subclassing tensorflow.keras.Model to implement a model. Expected behavior:
- Training (fitting) time: returns a list of tensors containing the final output and the auxiliary outputs;
- Inference (predicting) time: returns a single output tensor.
And the code is:
    class SomeModel(tensorflow.keras.Model):
        # ......
        def call(self, x, training=True):
            # ......
            return [aux1, aux2, net] if training else net
This is how I use it:
    model = SomeModel(...)
    model.compile(...,
                  loss=keras.losses.SparseCategoricalCrossentropy(),
                  loss_weights=[0.4, 0.4, 1], ...)
    # ......
    model.fit(data, [labels, labels, labels])
And got:
    AssertionError: in converted code:
        ipython-input-33-862e679ab098:140 call *
            return [aux1, aux2, net] if training else net
        ...\tensorflow_core\python\autograph\operators\control_flow.py:918 if_stmt
The problem seems to be that the if statement is converted into the computation graph, where both branches would presumably have to return the same structure, and that triggers the error. The full stack trace is long and adds nothing useful, so I have not included it here.
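A minimal sketch of what seems to be going on, assuming TensorFlow 2.x. The function `forward` and its arithmetic are hypothetical stand-ins for `call`, not the real model. When `training` is a plain Python bool, the if is resolved while tracing and each value gets its own graph. When `training` arrives as a tensor instead, the conditional has to live inside a single graph, which is likely the situation in the trace above.

```python
import tensorflow as tf


@tf.function
def forward(x, training):
    # Hypothetical stand-in for SomeModel.call: two outputs when training,
    # a single output otherwise.
    aux = x * 2.0
    net = x + 1.0
    if training:
        return [aux, net]
    return net


x = tf.zeros((2,))

# Python booleans: the branch is chosen at trace time, so tf.function
# builds one graph per value and the two return structures never meet.
train_out = forward(x, True)
infer_out = forward(x, False)

print(len(train_out))
print(infer_out.shape)

# Passing a tensor such as tf.constant(True) for `training` would instead
# require both branches in one graph; that call is left out on purpose.
```

If this reading is right, the error would depend on whether Keras hands `training` to `call` as a Python bool or as a tensor during `fit`, which I have not verified for this TensorFlow version.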
So, is there any way to make TensorFlow generate a different graph depending on whether training is true or not?