I have a 2D tensor in TensorFlow 2 (Python). How can I pick out and concatenate rows based on a ragged array of row indices, and then pad the shorter rows with zeros so that all rows end up with the same length?
Here is an example of what I have:
import tensorflow as tf

data = tf.constant([
    [300, 301, 302],
    [100, 101, 102],
    [200, 201, 202],
    [120, 121, 122],
    [210, 211, 212],
    [410, 411, 412],
    [110, 111, 112],
    [400, 401, 402],
], dtype=tf.float32)

# Ragged list: each inner list names the rows of `data` to concatenate.
row_ids = [[1, 6, 3], [2, 4], [0], [7, 5]]
And this is what I would like to get:
desired_result = tf.constant([
    [100, 101, 102, 110, 111, 112, 120, 121, 122],
    [200, 201, 202, 210, 211, 212,   0,   0,   0],
    [300, 301, 302,   0,   0,   0,   0,   0,   0],
    [400, 401, 402, 410, 411, 412,   0,   0,   0],
], dtype=tf.float32)
I have attempted to find a way with tf.RaggedTensor.from_value_rowids() and tf.gather_nd() with tf.concat(), but without any success.
I do need to backpropagate through this operation and, therefore, I need to stick to TensorFlow 2 operations.
Any suggestions would be greatly appreciated! Thanks!
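
For reference, here is a minimal sketch of one direction that might work, using only TensorFlow 2 ops: gather all requested rows in a single dense tf.gather() call, regroup the flattened values into a tf.RaggedTensor with row lengths scaled by the row width, and pad with to_tensor(). The helper name gather_concat_pad is hypothetical, and the sketch assumes every index in row_ids is a valid row of data.

```python
import tensorflow as tf

data = tf.constant([
    [300, 301, 302],
    [100, 101, 102],
    [200, 201, 202],
    [120, 121, 122],
    [210, 211, 212],
    [410, 411, 412],
    [110, 111, 112],
    [400, 401, 402],
], dtype=tf.float32)

row_ids = [[1, 6, 3], [2, 4], [0], [7, 5]]


def gather_concat_pad(data, row_ids):
    """Hypothetical helper: gather rows per group, concatenate, zero-pad."""
    ids = tf.ragged.constant(row_ids, dtype=tf.int32)
    # One dense gather over all indices: shape [total_ids, width].
    gathered = tf.gather(data, ids.flat_values)
    width = tf.cast(tf.shape(data)[1], tf.int64)
    # Each group of k indices contributes k * width scalar values.
    ragged = tf.RaggedTensor.from_row_lengths(
        values=tf.reshape(gathered, [-1]),
        row_lengths=ids.row_lengths() * width,
    )
    # Pad the shorter rows with zeros up to the longest row.
    return ragged.to_tensor(default_value=0.0)


# Check that gradients reach `data` through the whole operation.
with tf.GradientTape() as tape:
    tape.watch(data)
    result = gather_concat_pad(data, row_ids)
    loss = tf.reduce_sum(result)

# The gradient of tf.gather may come back as IndexedSlices, so densify it.
grad = tf.convert_to_tensor(tape.gradient(loss, data))
```

With the example data above, result should match desired_result, and grad has the same shape as data, which suggests that backpropagation passes through the gather, the ragged regrouping, and the padding.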