I'm implementing a custom layer in TensorFlow 2.x. My requirement is that the layer should check a condition before returning its output:
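```python
import tensorflow as tf

class SimpleRNN_cell(tf.keras.layers.Layer):
    def __init__(self, M1, M2, fi=tf.nn.tanh, disp_name=True):
        super(SimpleRNN_cell, self).__init__()
        self.fi = fi  # weight creation for M1/M2 omitted here

    def call(self, X, hidden_state, return_state=True):
        y = tf.constant(5)
        self.h = hidden_state  # placeholder for the updated hidden state
        if return_state:  # plain Python bool, not a tensor
            return y, self.h
        else:
            return y
```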
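For context, a minimal sketch of how a Python-bool check like `if return_state` behaves inside a `tf.function` (which `model.fit` typically uses to run the model as a graph): a plain Python bool argument is treated as a trace-time constant, so the `if` is resolved once per trace and each distinct value gets its own graph. The function `step` and the `traces` list are hypothetical, used only to count how often tracing happens.

```python
import tensorflow as tf

traces = []  # hypothetical counter; Python side effects only run while tracing

@tf.function
def step(x, return_state=True):
    traces.append(1)  # executes once per trace, not once per call
    y = x * 2.0
    if return_state:  # Python bool: branch is fixed when the graph is traced
        return y, x
    return y

step(tf.constant(1.0), return_state=True)   # first trace
step(tf.constant(2.0), return_state=True)   # same signature: reuses the graph
step(tf.constant(3.0), return_state=False)  # new Python value: traces again
print(len(traces))  # 2
```

Because each value of `return_state` gets its own graph, no `tf.cond()` is involved; the trade-off is one extra trace per distinct value.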
My question is: should I keep the present code (assuming that `tape.gradient(Loss, self.trainable_weights)` will work fine), or should I use `tf.cond()`?
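One way to check the gradient assumption is a small sketch with a hypothetical `forward` function standing in for the layer's `call`: since `return_state` is an ordinary Python bool, the tape only records the tensor operations that actually run, and the `if` itself does not get in the way of `tape.gradient`.

```python
import tensorflow as tf

w = tf.Variable(3.0)  # hypothetical trainable weight

def forward(x, return_state=True):
    y = w * x
    if return_state:  # ordinary Python control flow on a Python bool
        return y, x
    return y

with tf.GradientTape() as tape:
    y, _ = forward(tf.constant(2.0), return_state=True)
    loss = y ** 2  # loss = (w * x)^2

grad = tape.gradient(loss, w)  # d(loss)/dw = 2 * w * x * x
print(grad.numpy())  # 24.0
```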
Also, if possible, please explain where to use `tf.cond()` and where not to. I haven't found much content on this topic.
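As a hedged illustration of where `tf.cond()` comes in: when the predicate is itself a tensor (for example, computed from the input) and the code runs as a graph, Python cannot pick the branch at trace time. The sketch below, using a hypothetical variable `w`, puts an explicit `tf.cond()` next to a plain `if` that AutoGraph converts into an equivalent conditional inside `tf.function`; in both cases gradients flow only through the branch that is taken.

```python
import tensorflow as tf

w = tf.Variable(2.0)  # hypothetical trainable weight

@tf.function
def explicit_cond(x):
    # The predicate is a tensor, so the branch is chosen when the graph runs
    return tf.cond(x > 0.0, lambda: w * x, lambda: w * x * x)

@tf.function
def autograph_if(x):
    # AutoGraph rewrites this tensor-dependent `if` into a graph conditional
    if x > 0.0:
        y = w * x
    else:
        y = w * x * x
    return y

results = []
for fn in (explicit_cond, autograph_if):
    for value in (3.0, -3.0):
        with tf.GradientTape() as tape:
            y = fn(tf.constant(value))
        grad = tape.gradient(y, w)  # only the taken branch contributes
        results.append((float(y), float(grad)))

print(results)  # [(6.0, 3.0), (18.0, 9.0), (6.0, 3.0), (18.0, 9.0)]
```

In eager mode (no `tf.function`), a plain `if` on a scalar tensor also works, because the tensor's value is available immediately.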