
Is it possible to get the following minimal example working with experimental_compile=True? I've seen some big speedups with this argument, so I'm keen to figure out how to get it working. Thanks!

import tensorflow as tf

print(tf.__version__)
# ===> 2.2.0-dev20200409

x = tf.reshape(tf.range(25, dtype=tf.float32), [5, 5])
row_lengths = tf.constant([2, 1, 2])
ragged_tensor = tf.RaggedTensor.from_row_lengths(x, row_lengths)

for i, tensor in enumerate(ragged_tensor):
    print(f"i: {i}\ntensor:\n{tensor}\n")
# ==>
# i: 0
# tensor:
# [[0. 1. 2. 3. 4.]
#  [5. 6. 7. 8. 9.]]

# i: 1
# tensor:
# [[10. 11. 12. 13. 14.]]

# i: 2
# tensor:
# [[15. 16. 17. 18. 19.]
#  [20. 21. 22. 23. 24.]]


@tf.function(autograph=False, experimental_compile=True)
def while_loop_fail():

    num_rows = ragged_tensor.nrows()

    def cond(i, _):
        return i < num_rows

    def body(i, running_total):
        return i + 1, running_total + tf.reduce_sum(ragged_tensor[i])

    _, total = tf.while_loop(cond, body, [0, 0.0])

    return total


while_loop_fail()
# ===>
# tensorflow.python.framework.errors_impl.InvalidArgumentError: XLA can't deduce compile time constant output shape for strided slice: [?,5], output shape must be a compile-time constant
#    [[{{node while/RaggedGetItem/strided_slice_4}}]]
#    [[while]]
#   This error might be occurring with the use of xla.compile. If it is not necessary that every Op be compiled with XLA, an alternative is to use auto_jit with OptimizerOptions.global_jit_level = ON_2 or the environment variable TF_XLA_FLAGS="tf_xla_auto_jit=2" which will attempt to use xla to compile as much of the graph as the compiler is able to. [Op:__inference_while_loop_fail_481]
Jeff

2 Answers


There seem to be a lot of limitations on what XLA can do with ragged tensors. I can think of a couple of alternatives that would make your example work, but I don't know whether they will be applicable to your real use case. One option is to sum over the ragged dimension(s) in advance, or, as in your case, over all dimensions except the first one. However, this needs to be done outside of XLA, because XLA does not seem to be able to compile it:

import tensorflow as tf

x = tf.reshape(tf.range(25, dtype=tf.float32), [5, 5])
row_lengths = tf.constant([2, 1, 2])
ragged_tensor = tf.RaggedTensor.from_row_lengths(x, row_lengths)

# Sum in advance
ragged_sum = tf.reduce_sum(ragged_tensor, axis=[1, 2])

@tf.function(autograph=False, experimental_compile=True)
def while_loop_works():

    num_rows = ragged_tensor.nrows()

    def cond(i, _):
        return i < num_rows

    def body(i, running_total):
        # Use the sums computed before
        return i + 1, running_total + ragged_sum[i]

    _, total = tf.while_loop(cond, body, [0, 0.0])

    return total


result = while_loop_works()
print(result.numpy())
# 300.0
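Conceptually, `tf.reduce_sum(ragged_tensor, axis=[1, 2])` collapses each ragged row into a single scalar. The loop body then only indexes `ragged_sum[i]`, and every such element has the same fixed shape. Below is a minimal pure-Python sketch of that pre-summing step on the same data. It does not use TensorFlow, and `value_row_ids` and `segment_sums` are hypothetical helpers named after the idea, not TF APIs.

```python
# Pure-Python sketch (no TensorFlow; helper names are hypothetical): a ragged
# tensor is a list of flat "value rows" plus per-row lengths, so summing over
# the ragged dimensions is a segment sum keyed by each value row's row id.

def value_row_ids(row_lengths):
    """Row index of every flat value row, e.g. [2, 1, 2] -> [0, 0, 1, 2, 2]."""
    ids = []
    for row, length in enumerate(row_lengths):
        ids.extend([row] * length)
    return ids

def segment_sums(values, row_lengths):
    """Sum all elements belonging to each ragged row (like axis=[1, 2])."""
    sums = [0.0] * len(row_lengths)
    for row_id, value_row in zip(value_row_ids(row_lengths), values):
        sums[row_id] += sum(value_row)
    return sums

# Same data as x = tf.reshape(tf.range(25, dtype=tf.float32), [5, 5]).
values = [[float(5 * r + c) for c in range(5)] for r in range(5)]
row_lengths = [2, 1, 2]

ragged_sum = segment_sums(values, row_lengths)
print(ragged_sum)       # [45.0, 60.0, 195.0]
print(sum(ragged_sum))  # 300.0
```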

You can also just convert the ragged tensor into a regular tensor, which pads it with zeros that won't affect your sum. Again, this would currently need to be done outside of XLA:

import tensorflow as tf

x = tf.reshape(tf.range(25, dtype=tf.float32), [5, 5])
row_lengths = tf.constant([2, 1, 2])
ragged_tensor = tf.RaggedTensor.from_row_lengths(x, row_lengths)

# Convert into a regular tensor
unragged_tensor = ragged_tensor.to_tensor()

@tf.function(autograph=False, experimental_compile=True)
def while_loop_works():
    num_rows = ragged_tensor.nrows()

    def cond(i, _):
        return i < num_rows

    def body(i, running_total):
        # Reduce padded tensor
        return i + 1, running_total + tf.reduce_sum(unragged_tensor[i])

    _, total = tf.while_loop(cond, body, [0, 0.0])

    return total


result = while_loop_works()
print(result.numpy())
# 300.0
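Padding works because every row of the dense tensor has the same shape (here [2, 5]), so the shape of `unragged_tensor[i]` no longer depends on `i`. Below is a minimal pure-Python sketch of that padding, assuming the default zero fill of `to_tensor()`. `pad_rows` and `masked_mean` are hypothetical helpers. The sketch also shows that a reduction for which zero is not neutral, such as a per-row mean, needs a validity mask alongside the padding.

```python
# Pure-Python sketch (no TensorFlow; helper names are hypothetical) of padding
# a ragged tensor to a dense one, as RaggedTensor.to_tensor() does with 0 fill.

def pad_rows(values, row_lengths, pad_value=0.0):
    """Group flat value rows by row_lengths and pad each group to max length.

    Returns (padded, mask): every padded[i] has shape [max_len, width], and
    mask[i][j] is True where padded[i][j] holds real (non-padding) data.
    """
    width = len(values[0])
    max_len = max(row_lengths)
    padded, mask, start = [], [], 0
    for length in row_lengths:
        rows = [list(v) for v in values[start:start + length]]
        rows += [[pad_value] * width for _ in range(max_len - length)]
        padded.append(rows)
        mask.append([True] * length + [False] * (max_len - length))
        start += length
    return padded, mask

def masked_mean(block, block_mask):
    """Mean over the real elements of one padded row, ignoring padding."""
    valid = [x for row, ok in zip(block, block_mask) if ok for x in row]
    return sum(valid) / len(valid)

values = [[float(5 * r + c) for c in range(5)] for r in range(5)]
row_lengths = [2, 1, 2]
padded, mask = pad_rows(values, row_lengths)

# Zero padding leaves the overall sum unchanged.
total = sum(sum(sum(row) for row in block) for block in padded)
print(total)  # 300.0

# A plain mean over row 1 would give 60 / 10 = 6.0; the mask restores 12.0.
row_means = [masked_mean(b, m) for b, m in zip(padded, mask)]
print(row_means)  # [4.5, 12.0, 19.5]
```

In TensorFlow, a mask like this could be built outside the compiled function with, for example, `tf.sequence_mask(row_lengths)`. For these lengths it yields `[[True, True], [True, False], [True, True]]`.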
jdehesa
  • Thanks for the reply, @jdehesa. I guess your answer is "no, it's not possible", as unfortunately my use case is much more complicated than the `tf.reduce_sum`, which I just added as an example. – Jeff May 04 '20 at 12:14
  • @Jeff Well, I can't say I'm an expert on XLA, but it seems it needs to know the size of `ragged_tensor[i]` at compile time, which is obviously not possible because it changes with `i`. So my first impression is that it is not possible this way, unless there is some XLA-specific trick/option/tool I'm not aware of. I imagined your case would be more complicated; if you can share either the whole thing or a more representative snippet, maybe some other workaround could be found (maybe avoiding the while loop altogether...) – jdehesa May 04 '20 at 12:24
  • Unfortunately I can't share the code as it is work stuff :( I think it might be quite difficult to come up with a good example here but will try my best when I have time. – Jeff May 04 '20 at 14:19
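To make the reasoning in the comment above concrete: a ragged tensor stores its rows as slices of a flat `values` tensor delimited by `row_splits`, so the split points around `i` determine the shape of `ragged_tensor[i]`. Below is a pure-Python sketch (no TensorFlow) using the question's `row_lengths`. `row_shape` is a hypothetical helper.

```python
# Pure-Python sketch (no TensorFlow) of why ragged_tensor[i] has no single
# static shape: row i is the slice values[row_splits[i]:row_splits[i + 1]].
from itertools import accumulate

row_lengths = [2, 1, 2]
width = 5  # every value row has 5 columns, as in the question
row_splits = [0] + list(accumulate(row_lengths))

def row_shape(i):
    """Shape of ragged_tensor[i]: (number of value rows in row i, width)."""
    return (row_splits[i + 1] - row_splits[i], width)

shapes = [row_shape(i) for i in range(len(row_lengths))]
print(row_splits)  # [0, 2, 3, 5]
print(shapes)      # [(2, 5), (1, 5), (2, 5)]
```

Inside `tf.while_loop`, `i` is only known at run time, so the first dimension of the slice can only be described as `?`. This matches the `[?,5]` strided-slice shape in the XLA error.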

For anyone having this sort of issue, I just noticed that on TensorFlow 2.5 this works (replacing experimental_compile with jit_compile):

import tensorflow as tf

print(tf.__version__)
# 2.5.0

x = tf.reshape(tf.range(25, dtype=tf.float32), [5, 5])
row_lengths = tf.constant([2, 1, 2])
ragged_tensor = tf.RaggedTensor.from_row_lengths(x, row_lengths)

for i, tensor in enumerate(ragged_tensor):
    print(f"i: {i}\ntensor:\n{tensor}\n")
# ==>
# i: 0
# tensor:
# [[0. 1. 2. 3. 4.]
#  [5. 6. 7. 8. 9.]]

# i: 1
# tensor:
# [[10. 11. 12. 13. 14.]]

# i: 2
# tensor:
# [[15. 16. 17. 18. 19.]
#  [20. 21. 22. 23. 24.]]


@tf.function(autograph=False, jit_compile=True)
def while_loop_works():

    num_rows = ragged_tensor.nrows()

    def cond(i, _):
        return i < num_rows

    def body(i, running_total):
        return i + 1, running_total + tf.reduce_sum(ragged_tensor[i])

    _, total = tf.while_loop(cond, body, [0, 0.0])

    return total


while_loop_works()
# 2021-06-28 13:18:19.253261: I tensorflow/compiler/jit/xla_compilation_cache.cc:337] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
# <tf.Tensor: shape=(), dtype=float32, numpy=300.0>

Jeff