import tensorflow as tf

with tf.GradientTape() as g:
    x = tf.constant([[1., 2.], [3., 4.]], dtype=tf.float32)
    z = tf.constant([[5., 6., 3.], [7., 8., 4.]], dtype=tf.float32)
    g.watch(x)
    g.watch(z)
    y1 = x * x
    y2 = z * z
    y = tf.concat([y1, y2], axis=1)

batch_jacobian = g.batch_jacobian(y, z[:, 0:1], unconnected_gradients=tf.UnconnectedGradients.NONE)
The output is all zeros. If the slicing breaks the graph, I would expect to get None rather than zeros, since I passed unconnected_gradients=tf.UnconnectedGradients.NONE. The output I get is shown below:
batch_jacobian = [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]
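For comparison, here is a minimal sketch (my own, assuming TF 2.x) where the slice is taken inside the tape and the target actually flows through it. In that case the slice is a recorded node on the tape's graph, and batch_jacobian returns the expected nonzero values instead of zeros:

```python
import tensorflow as tf

z = tf.constant([[5., 6., 3.], [7., 8., 4.]], dtype=tf.float32)
with tf.GradientTape() as g:
    g.watch(z)
    s = z[:, 0:1]   # slice taken INSIDE the tape, so it is recorded
    y = s * s       # y genuinely depends on the slice s

# y has shape [2, 1] and s has shape [2, 1], so the result has shape [2, 1, 1]
jac = g.batch_jacobian(y, s)
# jac[i, 0, 0] == 2 * z[i, 0], i.e. 10. and 14.
```

By contrast, in my original snippet the slice z[:, 0:1] is a fresh tensor created at the batch_jacobian call, after the tape has stopped recording, so it is not connected to y at all.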