I am encountering strange behavior when trying to evaluate the derivatives of a result obtained from sparse tensor operations. If I convert all sparse inputs to dense tensors before operating on them, the code works as expected (first part of the code below), but it crashes with an InvalidArgumentError
when I perform the same computation on the sparse tensors directly. In addition, I get the while_loop
warnings shown below. Because the actual problem involves many more operations and many more, much larger tensors, I essentially have to compute the entries of c
in sparse mode. Can anyone make (more) sense of this behavior?
import tensorflow as tf
import numpy as np
# Minimal reproduction: Jacobian of c = a * f[0] + b * f[1] with respect to f,
# computed once on densified operands and once on the sparse operands.

# Two 2x2 sparse operands: `a` stores ones at (0, 0) and (1, 1); `b` stores
# minus ones at (0, 1) and (1, 0).
a = tf.SparseTensor(
    indices=[[0, 0], [1, 1]],
    values=np.array([1, 1], dtype=np.float32),
    dense_shape=(2, 2),
)
b = tf.SparseTensor(
    indices=[[0, 1], [1, 0]],
    values=np.array([-1, -1], dtype=np.float32),
    dense_shape=(2, 2),
)

# Dense mode: densify both operands first, then scale and add them.
f1 = tf.Variable([1, 1], dtype=np.float32)
with tf.GradientTape() as gtape:
    c = tf.sparse.to_dense(a) * f1[0] + tf.sparse.to_dense(b) * f1[1]
print(gtape.jacobian(c, f1))  # ... works fine

# Sparse mode: scale and add while still sparse, densify only the result.
# Both steps must run inside the tape so that the Jacobian can be traced back
# to f2.
# NOTE(review): the error message reports a rank-0 tensor reaching
# SparseTensorDenseAdd inside the gradient computation. Presumably this comes
# from the gradient of the sparse-times-scalar products `a * f2[0]` and
# `b * f2[1]`; scaling the `.values` of each SparseTensor directly might avoid
# it -- TODO confirm.
f2 = tf.Variable([1, 1], dtype=np.float32)
with tf.GradientTape() as gtape:
    c = tf.sparse.add(a * f2[0], b * f2[1], 0)
    c = tf.sparse.to_dense(c)
print(gtape.jacobian(c, f2))  # ... InvalidArgumentError
#WARNING:tensorflow:Using a while_loop for converting SparseAddGrad
#WARNING:tensorflow:Using a while_loop for converting SparseTensorDenseAdd
#WARNING:tensorflow:Using a while_loop for converting SparseTensorDenseAdd
#---------------------------------------------------------------------------
#InvalidArgumentError Traceback (most recent call last)
#<ipython-input-10-d449761ef6b2> in <module>
# 12 c=tf.sparse.add(a*f2[0],b*f2[1],0)
# 13 c=tf.sparse.to_dense(c)
#---> 14 print(gtape.jacobian(c,f2)) #InvalidArgumentError
#c:\program files\python37\lib\site-packages\tensorflow\python\eager\backprop.py in jacobian(self, target, sources, unconnected_gradients, parallel_iterations, experimental_use_pfor)
# 1187 try:
# 1188 output = pfor_ops.pfor(loop_fn, target_size,
#-> 1189 parallel_iterations=parallel_iterations)
# 1190 except ValueError as err:
# 1191 six.reraise(
#c:\program files\python37\lib\site-packages\tensorflow\python\ops\parallel_for\control_flow_ops.py in pfor(loop_fn, iters, fallback_to_while_loop, parallel_iterations)
# 203 def_function.run_functions_eagerly(False)
# 204 f = def_function.function(f)
#--> 205 outputs = f()
# 206 if functions_run_eagerly is not None:
# 207 def_function.run_functions_eagerly(functions_run_eagerly)
#c:\program files\python37\lib\site-packages\tensorflow\python\eager\def_function.py in __call__(self, *args, **kwds)
# 826 tracing_count = self.experimental_get_tracing_count()
# 827 with trace.Trace(self._name) as tm:
#--> 828 result = self._call(*args, **kwds)
# 829 compiler = "xla" if self._experimental_compile else "nonXla"
# 830 new_tracing_count = self.experimental_get_tracing_count()
#c:\program files\python37\lib\site-packages\tensorflow\python\eager\def_function.py in _call(self, *args, **kwds)
# 893 # If we did not create any variables the trace we have is good enough.
# 894 return self._concrete_stateful_fn._call_flat(
#--> 895 filtered_flat_args, self._concrete_stateful_fn.captured_inputs) # pylint: disable=protected-access
# 896
# 897 def fn_with_cond(inner_args, inner_kwds, inner_filtered_flat_args):
#c:\program files\python37\lib\site-packages\tensorflow\python\eager\function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
# 1917 # No tape is watching; skip to running the function.
# 1918 return self._build_call_outputs(self._inference_function.call(
#-> 1919 ctx, args, cancellation_manager=cancellation_manager))
# 1920 forward_backward = self._select_forward_and_backward_functions(
# 1921 args,
#c:\program files\python37\lib\site-packages\tensorflow\python\eager\function.py in call(self, ctx, args, cancellation_manager)
# 558 inputs=args,
# 559 attrs=attrs,
#--> 560 ctx=ctx)
# 561 else:
# 562 outputs = execute.execute_with_cancellation(
#c:\program files\python37\lib\site-packages\tensorflow\python\eager\execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
# 58 ctx.ensure_initialized()
# 59 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
#---> 60 inputs, attrs, num_outputs)
# 61 except core._NotOkStatusException as e:
# 62 if name is not None:
#InvalidArgumentError: Only tensors with ranks between 1 and 5 are currently supported. Tensor rank: 0
# [[{{node gradient_tape/SparseTensorDenseAdd_1/pfor/while/body/_56/gradient_tape/SparseTensorDenseAdd_1/pfor/while/SparseTensorDenseAdd}}]] [Op:__inference_f_6235]
#Function call stack:
#f